diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 838a9362f072d..dfeb15ddaa3a1 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -48,7 +48,7 @@ updates: - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" open-pull-requests-limit: 10 labels: [auto-dependencies] - package-ecosystem: "pip" diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 87143a04be4f9..5e75211e415f9 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -40,9 +40,9 @@ jobs: security_audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-audit - uses: taiki-e/install-action@3522286d40783523f9c7880e33f785905b4c20d0 # v2.66.1 + uses: taiki-e/install-action@f8d25fb8a2df08dcd3cead89780d572767b8655f # v2.68.0 with: tool: cargo-audit - name: Run audit check diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index fef65870b697d..3b2cc243d4967 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -44,7 +44,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -62,8 +62,8 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-machete run: cargo install cargo-machete --version ^0.9 --locked - name: Detect unused dependencies - run: cargo machete --with-metadata \ No newline at end of file + run: cargo machete --with-metadata diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 
1ec7c16b488f5..2fec343650914 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest name: Check License Header steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install HawkEye # This CI job is bound by installation time, use `--profile dev` to speed it up run: cargo install hawkeye --version 6.2.0 --locked --profile dev @@ -43,8 +43,8 @@ jobs: name: Use prettier to check formatting of documents runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Prettier check @@ -55,7 +55,7 @@ jobs: name: Spell Check with Typos runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false # Version fixed on purpose. 
It uses heuristics to detect typos, so upgrading diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 3e2c48643c366..b644d8721631e 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,16 +32,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout asf-site branch - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: asf-site path: asf-site - name: Setup Python - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.12" @@ -51,6 +51,12 @@ jobs: python3 -m venv venv source venv/bin/activate pip install -r docs/requirements.txt + - name: Install dependency graph tooling + run: | + set -x + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs run: | diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index 81eeb4039ba97..8d547b329ab41 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -40,12 +40,12 @@ jobs: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - name: Setup Python - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.12" - name: Install doc dependencies @@ -54,10 +54,15 @@ jobs: python3 -m venv venv source venv/bin/activate pip install -r 
docs/requirements.txt + - name: Install dependency graph tooling + run: | + set -x + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs html and check for warnings run: | set -x source venv/bin/activate cd docs ./build.sh # fails on errors - diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 01de0d5b77a7a..e9eb27dd96527 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -69,7 +69,7 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -93,7 +93,7 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -137,7 +137,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -158,7 +158,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 
01e21115010fc..06c58cd802e56 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -39,7 +39,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Assign GitHub labels if: | diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index b96b8cd4544ee..12b7bae76ab32 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -29,7 +29,7 @@ jobs: check-files: runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Check size of new Git objects diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 611410d7e0c96..70381edc65dc2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -45,11 +45,12 @@ jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m7a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -77,7 +78,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -102,7 
+103,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -139,7 +140,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -170,7 +171,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -235,7 +236,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -266,13 +267,14 @@ jobs: linux-test: name: cargo test (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m7a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust volumes: - /usr/local:/host/usr/local steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -318,7 +320,7 @@ jobs: needs: linux-build-lib runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -349,7 +351,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -379,7 +381,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -400,7 +402,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -412,7 +414,7 @@ jobs: name: build and run with wasm-pack runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup for wasm32 run: | rustup target add wasm32-unknown-unknown @@ -421,7 +423,7 @@ jobs: sudo apt-get update -qq sudo apt-get install -y -qq clang - name: Setup wasm-pack - uses: taiki-e/install-action@3522286d40783523f9c7880e33f785905b4c20d0 # v2.66.1 + uses: taiki-e/install-action@f8d25fb8a2df08dcd3cead89780d572767b8655f # v2.68.0 with: tool: wasm-pack - name: Run tests with headless mode @@ -440,7 +442,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -487,7 +489,7 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -511,7 +513,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -549,7 +551,7 @@ jobs: name: cargo test (macos-aarch64) runs-on: macos-14 steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -565,7 +567,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -582,7 +584,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -641,7 +643,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -666,7 +668,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -687,7 +689,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -695,7 +697,7 @@ 
jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Check if configs.md has been modified @@ -709,6 +711,11 @@ jobs: ./dev/update_function_docs.sh git diff --exit-code +# This job ensures `datafusion-examples/README.md` stays in sync with the source code: +# 1. Generates README automatically using the Rust examples docs generator +# (parsing documentation from `examples//main.rs`) +# 2. Formats the generated Markdown using DataFusion's standard Prettier setup +# 3. Compares the result against the committed README.md and fails if out-of-date examples-docs-check: name: check example README is up-to-date needs: linux-build-lib @@ -717,10 +724,20 @@ jobs: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 + + - name: Mark repository as safe for git + # Required for git commands inside container (avoids "dubious ownership" error) + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Set up Node.js (required for prettier) + # doc_prettier_check.sh uses npx to run prettier for Markdown formatting + uses: actions/setup-node@v6 + with: + node-version: '18' - name: Run examples docs check script run: | @@ -737,11 +754,11 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - uses: taiki-e/install-action@3522286d40783523f9c7880e33f785905b4c20d0 # v2.66.1 + uses: taiki-e/install-action@f8d25fb8a2df08dcd3cead89780d572767b8655f # v2.68.0 with: tool: 
cargo-msrv @@ -778,4 +795,4 @@ jobs: run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-proto working-directory: datafusion/proto - run: cargo msrv --output-format json --log-target stdout verify \ No newline at end of file + run: cargo msrv --output-format json --log-target stdout verify diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2aba1085b8329..ec7f54ec24dbc 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,7 +27,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 diff --git a/Cargo.lock b/Cargo.lock index 22a0a9f6e8023..7fd39099579fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,9 +232,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb372a7cbcac02a35d3fb7b3fc1f969ec078e871f9bb899bf00a2e1809bec8a3" +checksum = "e4754a624e5ae42081f464514be454b39711daae0458906dacde5f4c632f33a8" dependencies = [ "arrow-arith", "arrow-array", @@ -255,9 +255,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f377dcd19e440174596d83deb49cd724886d91060c07fec4f67014ef9d54049" +checksum = "f7b3141e0ec5145a22d8694ea8b6d6f69305971c4fa1c1a13ef0195aef2d678b" dependencies = [ "arrow-array", "arrow-buffer", @@ -269,9 +269,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "57.1.0" 
+version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eaff85a44e9fa914660fb0d0bb00b79c4a3d888b5334adb3ea4330c84f002" +checksum = "4c8955af33b25f3b175ee10af580577280b4bd01f7e823d94c7cdef7cf8c9aef" dependencies = [ "ahash", "arrow-buffer", @@ -288,9 +288,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2819d893750cb3380ab31ebdc8c68874dd4429f90fd09180f3c93538bd21626" +checksum = "c697ddca96183182f35b3a18e50b9110b11e916d7b7799cbfd4d34662f2c56c2" dependencies = [ "bytes", "half", @@ -300,9 +300,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d131abb183f80c450d4591dc784f8d7750c50c6e2bc3fcaad148afc8361271" +checksum = "646bbb821e86fd57189c10b4fcdaa941deaf4181924917b0daa92735baa6ada5" dependencies = [ "arrow-array", "arrow-buffer", @@ -322,9 +322,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2275877a0e5e7e7c76954669366c2aa1a829e340ab1f612e647507860906fb6b" +checksum = "8da746f4180004e3ce7b83c977daf6394d768332349d3d913998b10a120b790a" dependencies = [ "arrow-array", "arrow-cast", @@ -337,9 +337,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05738f3d42cb922b9096f7786f606fcb8669260c2640df8490533bb2fa38c9d3" +checksum = "1fdd994a9d28e6365aa78e15da3f3950c0fdcea6b963a12fa1c391afb637b304" dependencies = [ "arrow-buffer", "arrow-schema", @@ -350,9 +350,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8b5f57c3d39d1b1b7c1376a772ea86a131e7da310aed54ebea9363124bb885e3" +checksum = "58c5b083668e6230eae3eab2fc4b5fb989974c845d0aa538dde61a4327c78675" dependencies = [ "arrow-arith", "arrow-array", @@ -378,9 +378,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d09446e8076c4b3f235603d9ea7c5494e73d441b01cd61fb33d7254c11964b3" +checksum = "abf7df950701ab528bf7c0cf7eeadc0445d03ef5d6ffc151eaae6b38a58feff1" dependencies = [ "arrow-array", "arrow-buffer", @@ -394,9 +394,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "371ffd66fa77f71d7628c63f209c9ca5341081051aa32f9c8020feb0def787c0" +checksum = "0ff8357658bedc49792b13e2e862b80df908171275f8e6e075c460da5ee4bf86" dependencies = [ "arrow-array", "arrow-buffer", @@ -418,9 +418,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc94fc7adec5d1ba9e8cd1b1e8d6f72423b33fe978bf1f46d970fafab787521" +checksum = "f7d8f1870e03d4cbed632959498bcc84083b5a24bded52905ae1695bd29da45b" dependencies = [ "arrow-array", "arrow-buffer", @@ -431,9 +431,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "169676f317157dc079cc5def6354d16db63d8861d61046d2f3883268ced6f99f" +checksum = "18228633bad92bff92a95746bbeb16e5fc318e8382b75619dec26db79e4de4c0" dependencies = [ "arrow-array", "arrow-buffer", @@ -444,9 +444,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d27609cd7dd45f006abae27995c2729ef6f4b9361cde1ddd019dc31a5aa017e0" +checksum = 
"8c872d36b7bf2a6a6a2b40de9156265f0242910791db366a2c17476ba8330d68" dependencies = [ "bitflags", "serde", @@ -456,9 +456,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae980d021879ea119dd6e2a13912d81e64abed372d53163e804dfe84639d8010" +checksum = "68bf3e3efbd1278f770d67e5dc410257300b161b93baedb3aae836144edcaf4b" dependencies = [ "ahash", "arrow-array", @@ -470,9 +470,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf35e8ef49dcf0c5f6d175edee6b8af7b45611805333129c541a8b89a0fc0534" +checksum = "85e968097061b3c0e9fe3079cf2e703e487890700546b5b0647f60fca1b5a8d8" dependencies = [ "arrow-array", "arrow-buffer", @@ -515,9 +515,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.37" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d10e4f991a553474232bc0a31799f6d24b034a84c0971d80d2e2f78b2e576e40" +checksum = "68650b7df54f0293fd061972a0fb05aaf4fc0879d3b3d21a638a182c5c543b9f" dependencies = [ "compression-codecs", "compression-core", @@ -542,7 +542,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -564,7 +564,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -575,7 +575,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -601,9 +601,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.8.12" +version = "1.8.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5" +checksum = "8a8fc176d53d6fe85017f230405e3255cedb4a02221cb55ed6d76dccbbb099b2" dependencies = [ "aws-credential-types", "aws-runtime", @@ -631,9 +631,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.11" +version = "1.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" +checksum = "e26bbf46abc608f2dc61fd6cb3b7b0665497cc259a21520151ed98f8b37d2c79" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -665,9 +665,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.17" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b" +checksum = "b0f92058d22a46adf53ec57a6a96f34447daf02bff52e8fb956c66bcd5c6ac12" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -678,9 +678,10 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "bytes-utils", "fastrand", - "http 0.2.12", - "http-body 0.4.6", + "http 1.4.0", + "http-body 1.0.1", "percent-encoding", "pin-project-lite", "tracing", @@ -689,15 +690,16 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.91.0" +version = "1.94.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d" +checksum = "699da1961a289b23842d88fe2984c6ff68735fdf9bdcbc69ceaeb2491c9bf434" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -705,21 +707,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" 
-version = "1.93.0" +version = "1.96.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e" +checksum = "e3e3a4cb3b124833eafea9afd1a6cc5f8ddf3efefffc6651ef76a03cbc6b4981" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -727,21 +731,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.95.0" +version = "1.98.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164" +checksum = "89c4f19655ab0856375e169865c91264de965bd74c407c7f1e403184b1049409" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -750,15 +756,16 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.7" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" +checksum = "68f6ae9b71597dc5fd115d52849d7a5556ad9265885ad3492ea8d73b93bbc46e" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -778,9 +785,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.7" +version = "1.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" +checksum = "5cc50d0f63e714784b84223abd7abbc8577de8c35d699e0edd19f0a88a08ae13" dependencies = [ "futures-util", "pin-project-lite", @@ -789,9 +796,9 @@ dependencies = 
[ [[package]] name = "aws-smithy-http" -version = "0.62.6" +version = "0.63.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +checksum = "d619373d490ad70966994801bc126846afaa0d1ee920697a031f0cf63f2568e7" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -799,9 +806,9 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 0.2.12", "http 1.4.0", - "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", "percent-encoding", "pin-project-lite", "pin-utils", @@ -810,9 +817,9 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.1.5" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a" +checksum = "00ccbb08c10f6bcf912f398188e42ee2eab5f1767ce215a02a73bc5df1bbdd95" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -834,27 +841,27 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.9" +version = "0.62.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +checksum = "27b3a779093e18cad88bbae08dc4261e1d95018c4c5b9356a52bcae7c0b6e9bb" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.5" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" +checksum = "4d3f39d5bb871aaf461d59144557f16d5927a5248a983a40654d9cf3b9ba183b" dependencies = [ "aws-smithy-runtime-api", ] [[package]] name = "aws-smithy-query" -version = "0.60.9" +version = "0.60.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" +checksum = 
"05f76a580e3d8f8961e5d48763214025a2af65c2fa4cd1fb7f270a0e107a71b0" dependencies = [ "aws-smithy-types", "urlencoding", @@ -862,9 +869,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.9.5" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a392db6c583ea4a912538afb86b7be7c5d8887d91604f50eb55c262ee1b4a5f5" +checksum = "22ccf7f6eba8b2dcf8ce9b74806c6c185659c311665c4bf8d6e71ebd454db6bf" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -878,6 +885,7 @@ dependencies = [ "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", + "http-body-util", "pin-project-lite", "pin-utils", "tokio", @@ -886,9 +894,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.9.3" +version = "1.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2" +checksum = "b4af6e5def28be846479bbeac55aa4603d6f7986fc5da4601ba324dd5d377516" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -903,9 +911,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.5" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01" +checksum = "8ca2734c16913a45343b37313605d84e7d8b34a4611598ce1d25b35860a2bed3" dependencies = [ "base64-simd", "bytes", @@ -926,18 +934,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.13" +version = "0.60.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57" +checksum = "b53543b4b86ed43f051644f704a98c7291b3618b67adf057ee77a366fa52fcaa" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.11" +version = "1.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" +checksum = "6c50f3cdf47caa8d01f2be4a6663ea02418e892f9bbfd82c7b9a3a37eaccdd3a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -1163,7 +1171,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -1211,9 +1219,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "bytes-utils" @@ -1266,16 +1274,16 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" dependencies = [ "iana-time-zone", "js-sys", "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -1317,9 +1325,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.54" +version = "4.5.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" +checksum = "c5caf74d17c3aec5495110c34cc3f78644bfa89af6c8993ed4de2790e49b6499" dependencies = [ "clap_builder", "clap_derive", @@ -1327,9 +1335,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.54" +version = "4.5.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" +checksum = 
"370daa45065b80218950227371916a1633217ae42b2715b2287b606dcd618e24" dependencies = [ "anstream", "anstyle", @@ -1339,21 +1347,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.49" +version = "4.5.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "clap_lex" -version = "0.7.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" [[package]] name = "clipboard-win" @@ -1531,9 +1539,9 @@ dependencies = [ [[package]] name = "criterion" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ "alloca", "anes", @@ -1558,9 +1566,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", "itertools 0.13.0", @@ -1680,7 +1688,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -1691,7 +1699,7 @@ checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ "darling_core", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -1710,7 +1718,7 @@ dependencies = [ [[package]] name = "datafusion" -version = 
"52.0.0" +version = "52.1.0" dependencies = [ "arrow", "arrow-schema", @@ -1759,13 +1767,15 @@ dependencies = [ "itertools 0.14.0", "liblzma", "log", - "nix", + "nix 0.31.1", "object_store", "parking_lot", "parquet", "paste", + "pretty_assertions", "rand 0.9.2", "rand_distr", + "recursive", "regex", "rstest", "serde", @@ -1782,7 +1792,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "clap", @@ -1807,7 +1817,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -1830,7 +1840,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -1852,7 +1862,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -1883,7 +1893,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "ahash", "apache-avro", @@ -1896,6 +1906,7 @@ dependencies = [ "hex", "indexmap 2.13.0", "insta", + "itertools 0.14.0", "libc", "log", "object_store", @@ -1910,7 +1921,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "52.0.0" +version = "52.1.0" dependencies = [ "futures", "log", @@ -1919,7 +1930,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-compression", @@ -1954,7 +1965,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-arrow" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "arrow-ipc", @@ -1977,7 +1988,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "52.0.0" +version = "52.1.0" dependencies = [ "apache-avro", "arrow", @@ -1996,7 +2007,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" 
-version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2017,7 +2028,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2032,12 +2043,14 @@ dependencies = [ "datafusion-session", "futures", "object_store", + "serde_json", "tokio", + "tokio-stream", ] [[package]] name = "datafusion-datasource-parquet" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2069,11 +2082,11 @@ dependencies = [ [[package]] name = "datafusion-doc" -version = "52.0.0" +version = "52.1.0" [[package]] name = "datafusion-examples" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "arrow-flight", @@ -2093,10 +2106,12 @@ dependencies = [ "insta", "log", "mimalloc", - "nix", + "nix 0.31.1", + "nom", "object_store", "prost", "rand 0.9.2", + "serde", "serde_json", "strum", "strum_macros", @@ -2112,7 +2127,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2134,7 +2149,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2158,7 +2173,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2169,7 +2184,7 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "52.0.0" +version = "52.1.0" dependencies = [ "abi_stable", "arrow", @@ -2203,7 +2218,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -2225,6 +2240,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", + "memchr", "num-traits", "rand 0.9.2", "regex", @@ -2236,7 +2252,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "52.0.0" +version = 
"52.1.0" dependencies = [ "ahash", "arrow", @@ -2251,13 +2267,14 @@ dependencies = [ "datafusion-physical-expr-common", "half", "log", + "num-traits", "paste", "rand 0.9.2", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "ahash", "arrow", @@ -2270,7 +2287,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "arrow-ord", @@ -2293,7 +2310,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2307,7 +2324,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "criterion", @@ -2324,7 +2341,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2332,16 +2349,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "52.0.0" +version = "52.1.0" dependencies = [ "datafusion-doc", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "datafusion-optimizer" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2368,7 +2385,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "52.0.0" +version = "52.1.0" dependencies = [ "ahash", "arrow", @@ -2386,7 +2403,7 @@ dependencies = [ "itertools 0.14.0", "parking_lot", "paste", - "petgraph 0.8.3", + "petgraph", "rand 0.9.2", "recursive", "rstest", @@ -2395,7 +2412,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-adapter" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2408,7 +2425,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "ahash", 
"arrow", @@ -2423,7 +2440,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2443,7 +2460,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "52.0.0" +version = "52.1.0" dependencies = [ "ahash", "arrow", @@ -2480,7 +2497,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2507,9 +2524,10 @@ dependencies = [ "datafusion-proto-common", "doc-comment", "object_store", - "pbjson", + "pbjson 0.9.0", "pretty_assertions", "prost", + "rand 0.9.2", "serde", "serde_json", "tokio", @@ -2517,19 +2535,19 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", - "pbjson", + "pbjson 0.9.0", "prost", "serde", ] [[package]] name = "datafusion-pruning" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2547,7 +2565,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "52.0.0" +version = "52.1.0" dependencies = [ "async-trait", "datafusion-common", @@ -2559,13 +2577,14 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "bigdecimal", "chrono", "crc32fast", "criterion", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-execution", @@ -2577,12 +2596,13 @@ dependencies = [ "percent-encoding", "rand 0.9.2", "sha1", + "sha2", "url", ] [[package]] name = "datafusion-sql" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "bigdecimal", @@ -2608,7 +2628,7 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "52.0.0" +version = "52.1.0" dependencies = [ "arrow", "async-trait", @@ -2639,7 +2659,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = 
"52.0.0" +version = "52.1.0" dependencies = [ "async-recursion", "async-trait", @@ -2661,8 +2681,9 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "52.0.0" +version = "52.1.0" dependencies = [ + "bytes", "chrono", "console_error_panic_hook", "datafusion", @@ -2672,6 +2693,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", + "futures", "getrandom 0.3.4", "object_store", "tokio", @@ -2736,7 +2758,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -2792,7 +2814,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -2830,14 +2852,14 @@ checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "env_filter" -version = "0.1.4" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bf3c259d255ca70051b30e2e95b5446cdb8949ac4cd22c0d7fd634d89f568e2" +checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" dependencies = [ "log", "regex", @@ -2845,9 +2867,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" dependencies = [ "anstream", "anstyle", @@ -2994,13 +3016,13 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.5" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ 
"crc32fast", - "libz-rs-sys", "miniz_oxide", + "zlib-rs", ] [[package]] @@ -3101,7 +3123,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -3144,7 +3166,7 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3152,7 +3174,7 @@ dependencies = [ name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3184,7 +3206,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -3202,6 +3224,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + [[package]] name = "glob" version = "0.3.3" @@ -3506,7 +3541,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.2", + "windows-core", ] [[package]] @@ -3599,6 +3634,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -3651,9 +3692,9 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.18.3" +version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ "console 0.16.2", "portable-atomic", @@ -3664,9 +3705,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.46.0" +version = "1.46.3" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b66886d14d18d420ab5052cbff544fc5d34d0b2cdd35eb5976aaa10a4a472e5" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" dependencies = [ "console 0.15.11", "globset", @@ -3762,7 +3803,7 @@ checksum = "e0c84ee7f197eca9a86c6fd6cb771e55eb991632f15f2bc3ca6ec838929e6e78" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -3777,9 +3818,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.83" +version = "0.3.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" dependencies = [ "once_cell", "wasm-bindgen", @@ -3791,6 +3832,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "lexical-core" version = "1.0.6" @@ -3872,9 +3919,9 @@ dependencies = [ [[package]] name = "liblzma" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73c36d08cad03a3fbe2c4e7bb3a9e84c57e4ee4135ed0b065cade3d98480c648" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" dependencies = [ "liblzma-sys", ] @@ -3930,15 +3977,6 @@ dependencies = [ "escape8259", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c10501e7805cee23da17c7790e59df2870c0d4043ec6d03f67d31e2b53e77415" -dependencies = [ - "zlib-rs", -] - [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -3999,9 +4037,9 @@ 
dependencies = [ [[package]] name = "memchr" -version = "2.7.6" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "mimalloc" @@ -4045,7 +4083,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.61.2", ] @@ -4076,6 +4114,27 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "ntapi" version = "0.4.2" @@ -4130,9 +4189,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-integer" @@ -4177,18 +4236,18 @@ dependencies = [ [[package]] name = "objc2-core-foundation" -version = "0.3.2" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" dependencies = [ "bitflags", ] [[package]] name = "objc2-io-kit" -version = "0.3.2" +version = "0.3.1" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33fafba39597d6dc1fb709123dfa8289d39406734be322956a69f0931c73bb15" +checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" dependencies = [ "libc", "objc2-core-foundation", @@ -4205,9 +4264,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" +checksum = "fbfbfff40aeccab00ec8a910b57ca8ecf4319b335c542f2edcd19dd25a1e2a00" dependencies = [ "async-trait", "base64 0.22.1", @@ -4321,14 +4380,14 @@ dependencies = [ "libc", "redox_syscall 0.5.18", "smallvec", - "windows-link 0.2.1", + "windows-link", ] [[package]] name = "parquet" -version = "57.1.0" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be3e4f6d320dd92bfa7d612e265d7d08bba0a240bab86af3425e1d255a511d89" +checksum = "6ee96b29972a257b855ff2341b37e61af5f12d6af1158b6dcdb5b31ea07bb3cb" dependencies = [ "ahash", "arrow-array", @@ -4384,7 +4443,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -4403,6 +4462,16 @@ dependencies = [ "serde", ] +[[package]] +name = "pbjson" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8edd1efdd8ab23ba9cb9ace3d9987a72663d5d7c9f74fa00b51d6213645cf6c" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "pbjson-build" version = "0.8.0" @@ -4415,6 +4484,18 @@ dependencies = [ "prost-types", ] +[[package]] +name = "pbjson-build" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ed4d5c6ae95e08ac768883c8401cf0e8deb4e6e1d6a4e1fd3d2ec4f0ec63200" +dependencies = [ + "heck", + "itertools 0.14.0", + "prost", + "prost-types", +] + [[package]] name = "pbjson-types" version = "0.8.0" @@ -4423,8 
+4504,8 @@ checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" dependencies = [ "bytes", "chrono", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "prost", "prost-build", "serde", @@ -4436,16 +4517,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "petgraph" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" -dependencies = [ - "fixedbitset", - "indexmap 2.13.0", -] - [[package]] name = "petgraph" version = "0.8.3" @@ -4512,7 +4583,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -4585,7 +4656,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -4608,9 +4679,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4605b7c057056dd35baeb6ac0c0338e4975b1f2bef0f65da953285eb007095" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" dependencies = [ "bytes", "chrono", @@ -4660,7 +4731,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -4683,9 +4754,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ 
"bytes", "prost-derive", @@ -4693,42 +4764,41 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", - "petgraph 0.7.1", + "petgraph", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.114", + "syn 2.0.116", "tempfile", ] [[package]] name = "prost-derive" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "prost-types" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -4825,9 +4895,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.43" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" dependencies = [ "proc-macro2", ] @@ -4954,7 +5024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5003,14 +5073,14 @@ checksum = 
"b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "regex" -version = "1.12.2" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -5037,9 +5107,9 @@ checksum = "8d942b98df5e658f56f20d592c7f868833fe38115e65c33003d8cd224b0155da" [[package]] name = "regex-syntax" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" [[package]] name = "regress" @@ -5147,7 +5217,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.114", + "syn 2.0.116", "unicode-ident", ] @@ -5159,7 +5229,7 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5269,7 +5339,7 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.30.1", "radix_trie", "unicode-segmentation", "unicode-width 0.2.2", @@ -5346,7 +5416,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5431,7 +5501,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5442,7 +5512,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5466,7 +5536,7 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5478,7 +5548,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5521,7 +5591,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5656,9 +5726,9 @@ dependencies = [ [[package]] name = "sqllogictest" -version = "0.29.0" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffbf03091090a9330529c3926313be0a0570f036edfd490b11db39eea4b7118" +checksum = "d03b2262a244037b0b510edbd25a8e6c9fb8d73ee0237fc6cc95a54c16f94a82" dependencies = [ "async-trait", "educe", @@ -5681,9 +5751,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.59.0" +version = "0.60.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" +checksum = "505aa16b045c4c1375bf5f125cce3813d0176325bfe9ffc4a903f423de7774ff" dependencies = [ "log", "recursive", @@ -5692,13 +5762,13 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" +checksum = "028e551d5e270b31b9f3ea271778d9d827148d4287a5d96167b6bb9787f5cc38" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5746,7 +5816,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5757,7 +5827,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5775,7 +5845,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5795,8 +5865,8 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "62fc4b483a129b9772ccb9c3f7945a472112fdd9140da87f8a4e7f1d44e045d0" dependencies = [ "heck", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "pbjson-types", "prettyplease", "prost", @@ -5809,7 +5879,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.114", + "syn 2.0.116", "typify", "walkdir", ] @@ -5833,9 +5903,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.114" +version = "2.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb" dependencies = [ "proc-macro2", "quote", @@ -5859,14 +5929,14 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "sysinfo" -version = "0.37.2" +version = "0.38.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16607d5caffd1c07ce073528f9ed972d88db15dd44023fa57142963be3feb11f" +checksum = "1efc19935b4b66baa6f654ac7924c192f55b175c00a7ab72410fc24284dacda8" dependencies = [ "libc", "memchr", @@ -5878,12 +5948,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.24.0" +version = "3.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.1", "once_cell", "rustix", "windows-sys 0.61.2", @@ -5941,22 +6011,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = 
"4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -5981,30 +6051,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", @@ -6079,14 +6149,14 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "tokio-postgres" -version = "0.7.15" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b40d66d9b2cfe04b628173409368e58247e8eddbbd3b0e6c6ba1d09f20f6c9e" +checksum = 
"dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" dependencies = [ "async-trait", "byteorder", @@ -6127,6 +6197,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] @@ -6174,9 +6245,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.14.2" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +checksum = "7f32a6f80051a4111560201420c7885d0082ba9efe2ab61875c587bb6b18b9a0" dependencies = [ "async-trait", "axum", @@ -6280,7 +6351,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -6388,7 +6459,7 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.114", + "syn 2.0.116", "thiserror", "unicode-ident", ] @@ -6406,7 +6477,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.114", + "syn 2.0.116", "typify-impl", ] @@ -6455,6 +6526,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unit-prefix" version = "0.5.2" @@ -6540,11 +6617,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.19.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" dependencies = [ - "getrandom 0.3.4", + "getrandom 0.4.1", "js-sys", "serde_core", "wasm-bindgen", @@ -6593,26 
+6670,47 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.46.0", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", ] [[package]] name = "wasite" -version = "0.1.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] [[package]] name = "wasm-bindgen" -version = "0.2.106" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" dependencies = [ "cfg-if", "once_cell", @@ -6623,11 +6721,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.56" +version = "0.4.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" +checksum = "70a6e77fd0ae8029c9ea0063f87c46fde723e7d887703d74ad2616d792e51e6f" dependencies = 
[ "cfg-if", + "futures-util", "js-sys", "once_cell", "wasm-bindgen", @@ -6636,9 +6735,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.106" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6646,31 +6745,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.106" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.106" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.56" +version = "0.3.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e90e66d265d3a1efc0e72a54809ab90b9c0c515915c67cdf658689d2c22c6c" +checksum = "45649196a53b0b7a15101d845d44d2dda7374fc1b5b5e2bbf58b7577ff4b346d" dependencies = [ "async-trait", "cast", @@ -6685,17 +6784,46 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", ] [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.56" +version = "0.3.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7150335716dce6028bead2b848e72f47b45e7b9422f64cccdc23bedca89affc1" +checksum = 
"f579cdd0123ac74b94e1a4a72bd963cf30ebac343f2df347da0b8df24cdebed2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8145dd1593bf0fb137dbfa85b8be79ec560a447298955877804640e40c2d6ea" + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.13.0", + "wasm-encoder", + "wasmparser", ] [[package]] @@ -6711,11 +6839,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap 2.13.0", + "semver", +] + [[package]] name = "web-sys" -version = "0.3.83" +version = "0.3.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" +checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" dependencies = [ "js-sys", "wasm-bindgen", @@ -6742,9 +6882,9 @@ dependencies = [ [[package]] name = "whoami" -version = "1.6.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +checksum = "ace4d5c7b5ab3d99629156d4e0997edbe98a4beb6d5ba99e2cae830207a81983" dependencies = [ "libredox", "wasite", @@ -6784,37 +6924,23 @@ 
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.61.3" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" dependencies = [ "windows-collections", - "windows-core 0.61.2", + "windows-core", "windows-future", - "windows-link 0.1.3", "windows-numerics", ] [[package]] name = "windows-collections" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" -dependencies = [ - "windows-core 0.61.2", -] - -[[package]] -name = "windows-core" -version = "0.61.2" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" dependencies = [ - "windows-implement", - "windows-interface", - "windows-link 0.1.3", - "windows-result 0.3.4", - "windows-strings 0.4.2", + "windows-core", ] [[package]] @@ -6825,19 +6951,19 @@ checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.1", - "windows-result 0.4.1", - "windows-strings 0.5.1", + "windows-link", + "windows-result", + "windows-strings", ] [[package]] name = "windows-future" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", + "windows-core", + "windows-link", "windows-threading", ] @@ -6849,7 +6975,7 @@ checksum = 
"053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -6860,15 +6986,9 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] -[[package]] -name = "windows-link" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - [[package]] name = "windows-link" version = "0.2.1" @@ -6877,21 +6997,12 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-numerics" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" -dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", -] - -[[package]] -name = "windows-result" -version = "0.3.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" dependencies = [ - "windows-link 0.1.3", + "windows-core", + "windows-link", ] [[package]] @@ -6900,16 +7011,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.2.1", -] - -[[package]] -name = "windows-strings" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" -dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -6918,7 +7020,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6954,7 +7056,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6979,7 +7081,7 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows_aarch64_gnullvm 0.53.1", "windows_aarch64_msvc 0.53.1", "windows_i686_gnu 0.53.1", @@ -6992,11 +7094,11 @@ dependencies = [ [[package]] name = "windows-threading" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -7110,6 +7212,94 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.13.0", + "prettyplease", + "syn 2.0.116", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.116", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap 2.13.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.13.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "writeable" version = "0.6.2" @@ -7157,7 +7347,7 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", "synstructure", ] @@ -7178,7 +7368,7 @@ checksum = "c9c2d862265a8bb4471d87e033e730f536e2a285cc7cb05dbce09a2a97075f90" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] @@ -7198,7 +7388,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", "synstructure", ] @@ -7238,14 +7428,14 @@ 
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.114", + "syn 2.0.116", ] [[package]] name = "zlib-rs" -version = "0.5.5" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" +checksum = "a7948af682ccbc3342b6e9420e8c51c1fe5d7bf7756002b4a3c6cabfe96a7e3c" [[package]] name = "zmij" diff --git a/Cargo.toml b/Cargo.toml index 8ab3e2c535571..2186632113511 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.88.0" # Define DataFusion version -version = "52.0.0" +version = "52.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -91,89 +91,90 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.21", default-features = false } -arrow = { version = "57.1.0", features = [ +arrow = { version = "57.3.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "57.1.0", default-features = false } -arrow-flight = { version = "57.1.0", features = [ +arrow-buffer = { version = "57.2.0", default-features = false } +arrow-flight = { version = "57.3.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "57.1.0", default-features = false, features = [ +arrow-ipc = { version = "57.2.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "57.1.0", default-features = false } -arrow-schema = { version = "57.1.0", default-features = false } +arrow-ord = { version = "57.2.0", default-features = false } +arrow-schema = { version = "57.2.0", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" bytes = "1.11" bzip2 = "0.6.1" -chrono = { version = "0.4.42", 
default-features = false } +chrono = { version = "0.4.43", default-features = false } criterion = "0.8" ctor = "0.6.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "52.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "52.0.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "52.0.0" } -datafusion-common = { path = "datafusion/common", version = "52.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "52.0.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "52.0.0", default-features = false } -datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "52.0.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "52.0.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "52.0.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "52.0.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "52.0.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "52.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "52.0.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "52.0.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "52.0.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "52.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "52.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "52.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "52.0.0" } -datafusion-functions-nested = 
{ path = "datafusion/functions-nested", version = "52.0.0", default-features = false } -datafusion-functions-table = { path = "datafusion/functions-table", version = "52.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "52.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "52.0.0" } -datafusion-macros = { path = "datafusion/macros", version = "52.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "52.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "52.0.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "52.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "52.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "52.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "52.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "52.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "52.0.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "52.0.0" } -datafusion-session = { path = "datafusion/session", version = "52.0.0" } -datafusion-spark = { path = "datafusion/spark", version = "52.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "52.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "52.0.0" } +datafusion = { path = "datafusion/core", version = "52.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "52.1.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "52.1.0" } +datafusion-common = { path = "datafusion/common", version = "52.1.0", default-features = false } +datafusion-common-runtime = { 
path = "datafusion/common-runtime", version = "52.1.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "52.1.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "52.1.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "52.1.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "52.1.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "52.1.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "52.1.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "52.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "52.1.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "52.1.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "52.1.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "52.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "52.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "52.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "52.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "52.1.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "52.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "52.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "52.1.0" } +datafusion-macros = { path = "datafusion/macros", version = "52.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = 
"52.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "52.1.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "52.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "52.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "52.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "52.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "52.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "52.1.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "52.1.0" } +datafusion-session = { path = "datafusion/session", version = "52.1.0" } +datafusion-spark = { path = "datafusion/spark", version = "52.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "52.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "52.1.0" } doc-comment = "0.3" env_logger = "0.11" -flate2 = "1.1.5" +flate2 = "1.1.9" futures = "0.3" glob = "0.3.0" half = { version = "2.7.0", default-features = false } hashbrown = { version = "0.16.1" } hex = { version = "0.4.3" } indexmap = "2.13.0" -insta = { version = "1.46.0", features = ["glob", "filters"] } +insta = { version = "1.46.3", features = ["glob", "filters"] } itertools = "0.14" -liblzma = { version = "0.4.4", features = ["static"] } +liblzma = { version = "0.4.6", features = ["static"] } log = "^0.4" +memchr = "2.8.0" num-traits = { version = "0.2" } -object_store = { version = "0.12.4", default-features = false } +object_store = { version = "0.12.5", default-features = false } parking_lot = "0.12" -parquet = { version = "57.1.0", default-features = false, features = [ +parquet = { version = "57.3.0", default-features = false, features = [ "arrow", "async", "object_store", ] 
} paste = "1.0.15" -pbjson = { version = "0.8.0" } -pbjson-types = "0.8" +pbjson = { version = "0.9.0" } +pbjson-types = "0.9" # Should match arrow-flight's version of prost. prost = "0.14.1" rand = "0.9" @@ -181,13 +182,17 @@ recursive = "0.1.1" regex = "1.12" rstest = "0.26.1" serde_json = "1" -sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } +sha2 = "^0.10.9" +sqlparser = { version = "0.60.0", default-features = false, features = ["std", "visitor"] } strum = "0.27.2" strum_macros = "0.27.2" tempfile = "3" testcontainers-modules = { version = "0.14" } tokio = { version = "1.48", features = ["macros", "rt", "sync"] } +tokio-stream = "0.1" +tokio-util = "0.7" url = "2.5.7" +uuid = "1.21" zstd = { version = "0.13", default-features = false } [workspace.lints.clippy] @@ -200,6 +205,8 @@ uninlined_format_args = "warn" inefficient_to_string = "warn" # https://github.com/apache/datafusion/issues/18503 needless_pass_by_value = "warn" +# https://github.com/apache/datafusion/issues/18881 +allow_attributes = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = [ diff --git a/README.md b/README.md index 880adfb3ac392..630d4295bd427 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ DataFusion is an extensible query engine written in [Rust] that uses [Apache Arrow] as its in-memory format. This crate provides libraries and binaries for developers building fast and -feature rich database and analytic systems, customized to particular workloads. +feature-rich database and analytic systems, customized for particular workloads. See [use cases] for examples. The following related subprojects target end users: - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame @@ -67,7 +67,7 @@ See [use cases] for examples. The following related subprojects target end users DataFusion. 
"Out of the box," -DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [Dataframe](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], +DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. @@ -84,7 +84,7 @@ See the [Architecture] section for more details. [performance]: https://benchmark.clickhouse.com/ [architecture]: https://datafusion.apache.org/contributor-guide/architecture.html -Here are links to some important information +Here are links to important resources: - [Project Site](https://datafusion.apache.org/) - [Installation](https://datafusion.apache.org/user-guide/cli/installation.html) @@ -97,8 +97,8 @@ Here are links to some important information ## What can you do with this crate? -DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. -It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users. +DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your needs. See the [list of known users](https://datafusion.apache.org/user-guide/introduction.html#known-users). ## Contributing to DataFusion @@ -115,15 +115,15 @@ This crate has several [features] which can be specified in your `Cargo.toml`. 
Default features: -- `nested_expressions`: functions for working with nested type function such as `array_to_string` +- `nested_expressions`: functions for working with nested types such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` - `encoding_expressions`: `encode` and `decode` functions - `parquet`: support for reading the [Apache Parquet] format -- `sql`: Support for sql parsing / planning +- `sql`: support for SQL parsing and planning - `regex_expressions`: regular expression functions, such as `regexp_match` -- `unicode_expressions`: Include unicode aware functions such as `character_length` +- `unicode_expressions`: include Unicode-aware functions such as `character_length` - `unparser`: enables support to reverse LogicalPlans back into SQL - `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index df04f56235ec3..a07be54948e86 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -40,7 +40,7 @@ mimalloc_extended = ["libmimalloc-sys/extended"] [dependencies] arrow = { workspace = true } -clap = { version = "4.5.53", features = ["derive"] } +clap = { version = "4.5.59", features = ["derive"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index fd58c17f8ab40..e7f643a5d51d5 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -683,7 +683,7 @@ run_tpch_mem() { # Runs the tpcds benchmark run_tpcds() { - TPCDS_DIR="${DATA_DIR}" + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" # Check if TPCDS data directory and representative file exists if [ ! 
-f "${TPCDS_DIR}/web_site.parquet" ]; then diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index a9da57b02ae32..c0f911c566f4d 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -29,6 +29,16 @@ use datafusion::{ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; +/// SQL to create the hits view with proper EventDate casting. +/// +/// ClickBench stores EventDate as UInt16 (days since 1970-01-01) for +/// storage efficiency (2 bytes vs 4-8 bytes for date types). +/// This view transforms it to SQL DATE type for query compatibility. +const HITS_VIEW_DDL: &str = r#"CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw"#; + /// Driver program to run the ClickBench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and @@ -295,7 +305,7 @@ impl RunOpt { // Build CREATE EXTERNAL TABLE DDL with WITH ORDER clause // Schema will be automatically inferred from the Parquet file let create_table_sql = format!( - "CREATE EXTERNAL TABLE hits \ + "CREATE EXTERNAL TABLE hits_raw \ STORED AS PARQUET \ LOCATION '{}' \ WITH ORDER ({} {})", @@ -308,20 +318,34 @@ impl RunOpt { // Execute the CREATE EXTERNAL TABLE statement ctx.sql(&create_table_sql).await?.collect().await?; - - Ok(()) } else { // Original registration without sort order let options = Default::default(); - ctx.register_parquet("hits", path, options) + ctx.register_parquet("hits_raw", path, options) .await .map_err(|e| { DataFusionError::Context( - format!("Registering 'hits' as {path}"), + format!("Registering 'hits_raw' as {path}"), Box::new(e), ) - }) + })?; } + + // Create the hits view with EventDate transformation + Self::create_hits_view(ctx).await + } + + /// Creates the hits view with EventDate transformation from UInt16 to DATE. + /// + /// ClickBench encodes EventDate as UInt16 days since epoch (1970-01-01). 
+ async fn create_hits_view(ctx: &SessionContext) -> Result<()> { + ctx.sql(HITS_VIEW_DDL).await?.collect().await.map_err(|e| { + DataFusionError::Context( + "Creating 'hits' view with EventDate transformation".to_string(), + Box::new(e), + ) + })?; + Ok(()) } fn iterations(&self) -> usize { diff --git a/ci/scripts/check_examples_docs.sh b/ci/scripts/check_examples_docs.sh index 37b0cc088df4c..62308b323b535 100755 --- a/ci/scripts/check_examples_docs.sh +++ b/ci/scripts/check_examples_docs.sh @@ -17,48 +17,61 @@ # specific language governing permissions and limitations # under the License. -set -euo pipefail - -EXAMPLES_DIR="datafusion-examples/examples" -README="datafusion-examples/README.md" +# Generates documentation for DataFusion examples using the Rust-based +# documentation generator and verifies that the committed README.md +# is up to date. +# +# The README is generated from documentation comments in: +# datafusion-examples/examples//main.rs +# +# This script is intended to be run in CI to ensure that example +# documentation stays in sync with the code. +# +# To update the README locally, run this script and replace README.md +# with the generated output. -# ffi examples are skipped because they were not part of the recent example -# consolidation work and do not follow the new grouping and execution pattern. -# They are not documented in the README using the new structure, so including -# them here would cause false CI failures. 
-SKIP_LIST=("ffi") +set -euo pipefail -missing=0 +ROOT_DIR="$(git rev-parse --show-toplevel)" -skip() { - local value="$1" - for item in "${SKIP_LIST[@]}"; do - if [[ "$item" == "$value" ]]; then - return 0 - fi - done - return 1 -} +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" -# collect folder names -folders=$(find "$EXAMPLES_DIR" -mindepth 1 -maxdepth 1 -type d -exec basename {} \;) +EXAMPLES_DIR="$ROOT_DIR/datafusion-examples" +README="$EXAMPLES_DIR/README.md" +README_NEW="$EXAMPLES_DIR/README-NEW.md" -# collect group names from README headers -groups=$(grep "^### Group:" "$README" | sed -E 's/^### Group: `([^`]+)`.*/\1/') +echo "▶ Generating examples README (Rust generator)…" +cargo run --quiet \ + --manifest-path "$EXAMPLES_DIR/Cargo.toml" \ + --bin examples-docs \ + > "$README_NEW" -for folder in $folders; do - if skip "$folder"; then - echo "Skipped group: $folder" - continue - fi +echo "▶ Formatting generated README with prettier ${PRETTIER_VERSION}…" +npx "prettier@${PRETTIER_VERSION}" \ + --parser markdown \ + --write "$README_NEW" - if ! echo "$groups" | grep -qx "$folder"; then - echo "Missing README entry for example group: $folder" - missing=1 - fi -done +echo "▶ Comparing generated README with committed version…" -if [[ $missing -eq 1 ]]; then - echo "README is out of sync with examples" - exit 1 +if ! diff -u "$README" "$README_NEW" > /tmp/examples-readme.diff; then + echo "" + echo "❌ Examples README is out of date." 
+ echo "" + echo "The examples documentation is generated automatically from:" + echo " - datafusion-examples/examples//main.rs" + echo "" + echo "To update the README locally, run:" + echo "" + echo " cargo run --bin examples-docs \\" + echo " | npx prettier@${PRETTIER_VERSION} --parser markdown --write \\" + echo " > datafusion-examples/README.md" + echo "" + echo "Diff:" + echo "------------------------------------------------------------" + cat /tmp/examples-readme.diff + echo "------------------------------------------------------------" + exit 1 fi + +echo "✅ Examples README is up-to-date." diff --git a/ci/scripts/doc_prettier_check.sh b/ci/scripts/doc_prettier_check.sh index d94a0d1c96171..95332eb65aaf2 100755 --- a/ci/scripts/doc_prettier_check.sh +++ b/ci/scripts/doc_prettier_check.sh @@ -17,41 +17,70 @@ # specific language governing permissions and limitations # under the License. -SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")" - -MODE="--check" -ACTION="Checking" -if [ $# -gt 0 ]; then - if [ "$1" = "--write" ]; then - MODE="--write" - ACTION="Formatting" - else - echo "Usage: $0 [--write]" >&2 - exit 1 - fi +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +# Load shared utilities and tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" +source "${ROOT_DIR}/ci/scripts/utils/git.sh" + +PRETTIER_TARGETS=( + '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' + '!datafusion/CHANGELOG.md' + README.md + CONTRIBUTING.md +) + +MODE="check" +ALLOW_DIRTY=0 + +usage() { + cat >&2 </dev/null 2>&1; then echo "npx is required to run the prettier check. Install Node.js (e.g., brew install node) and re-run." 
>&2 exit 1 fi - -# Ignore subproject CHANGELOG.md because it is machine generated -npx prettier@2.7.1 $MODE \ - '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \ - '!datafusion/CHANGELOG.md' \ - README.md \ - CONTRIBUTING.md -status=$? - -if [ $status -ne 0 ]; then - if [ "$MODE" = "--check" ]; then - echo "Prettier check failed. Re-run with --write (e.g., ./ci/scripts/doc_prettier_check.sh --write) to format files, commit the changes, and re-run the check." >&2 - else - echo "Prettier format failed. Files may have been modified; commit any changes and re-run." >&2 - fi - exit $status + +PRETTIER_MODE=(--check) +if [[ "$MODE" == "write" ]]; then + PRETTIER_MODE=(--write) fi + +# Ignore subproject CHANGELOG.md because it is machine generated +npx "prettier@${PRETTIER_VERSION}" "${PRETTIER_MODE[@]}" "${PRETTIER_TARGETS[@]}" diff --git a/ci/scripts/license_header.sh b/ci/scripts/license_header.sh index 5345728f9cdf0..7ab8c9637598b 100755 --- a/ci/scripts/license_header.sh +++ b/ci/scripts/license_header.sh @@ -17,6 +17,62 @@ # specific language governing permissions and limitations # under the License. 
-# Check Apache license header -set -ex -hawkeye check --config licenserc.toml +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +source "${SCRIPT_DIR}/utils/git.sh" + +MODE="check" +ALLOW_DIRTY=0 +HAWKEYE_CONFIG="licenserc.toml" + +usage() { + cat >&2 <&2 <&2 <&2 <&2 <&2 + return 1 + fi +} diff --git a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt b/ci/scripts/utils/tool_versions.sh similarity index 55% rename from datafusion/sqllogictest/test_files/spark/string/unbase64.slt rename to ci/scripts/utils/tool_versions.sh index 5cf3fbee0455d..ac731ed0d5341 100644 --- a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt +++ b/ci/scripts/utils/tool_versions.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -5,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -15,13 +17,7 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 +# This file defines centralized tool versions used by CI and development scripts. 
+# It is intended to be sourced by other scripts and should not be executed directly. -## Original Query: SELECT unbase64('U3BhcmsgU1FM'); -## PySpark 3.5.5 Result: {'unbase64(U3BhcmsgU1FM)': bytearray(b'Spark SQL'), 'typeof(unbase64(U3BhcmsgU1FM))': 'binary', 'typeof(U3BhcmsgU1FM)': 'string'} -#query -#SELECT unbase64('U3BhcmsgU1FM'::string); +PRETTIER_VERSION="2.7.1" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 31941d87165a6..c58b9d9061863 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,10 +37,10 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.8.12" +aws-config = "1.8.14" aws-credential-types = "1.2.7" chrono = { workspace = true } -clap = { version = "4.5.53", features = ["cargo", "derive"] } +clap = { version = "4.5.59", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", "compression", diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 2b8385ac2d89c..94bd8ee2c4f9d 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -269,7 +269,7 @@ impl StatementExecutor { let options = task_ctx.session_config().options(); // Track memory usage for the query result if it's bounded - let mut reservation = + let reservation = MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool()); if physical_plan.boundedness().is_unbounded() { @@ -300,7 +300,7 @@ impl StatementExecutor { let curr_num_rows = batch.num_rows(); // Stop collecting results if the number of rows exceeds the limit // results batch should include the last batch that exceeds the limit - if row_count < max_rows + curr_num_rows { + if row_count < max_rows.saturating_add(curr_num_rows) { // Try to grow the reservation to accommodate the batch in memory reservation.try_grow(get_record_batch_memory_size(&batch))?; results.push(batch); diff --git a/datafusion-cli/src/functions.rs 
b/datafusion-cli/src/functions.rs index 6a97c5355ffc7..cef057545c113 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -426,7 +426,7 @@ impl TableFunctionImpl for ParquetMetadataFunc { compression_arr.push(format!("{:?}", column.compression())); // need to collect into Vec to format let encodings: Vec<_> = column.encodings().collect(); - encodings_arr.push(format!("{:?}", encodings)); + encodings_arr.push(format!("{encodings:?}")); index_page_offset_arr.push(column.index_page_offset()); dictionary_page_offset_arr.push(column.dictionary_page_offset()); data_page_offset_arr.push(column.data_page_offset()); @@ -703,10 +703,13 @@ impl TableFunctionImpl for StatisticsCacheFunc { } } -// Implementation of the `list_files_cache` table function in datafusion-cli. +/// Implementation of the `list_files_cache` table function in datafusion-cli. +/// +/// This function returns the cached results of running a LIST command on a +/// particular object store path for a table. The object metadata is returned as +/// a List of Structs, with one Struct for each object. DataFusion uses these +/// cached results to plan queries against external tables. /// -/// This function returns the cached results of running a LIST command on a particular object store path for a table. The object metadata is returned as a List of Structs, with one Struct for each object. -/// DataFusion uses these cached results to plan queries against external tables. 
/// # Schema /// ```sql /// > describe select * from list_files_cache(); @@ -788,7 +791,7 @@ impl TableFunctionImpl for ListFilesCacheFunc { Field::new("metadata", DataType::Struct(nested_fields.clone()), true); let schema = Arc::new(Schema::new(vec![ - Field::new("table", DataType::Utf8, false), + Field::new("table", DataType::Utf8, true), Field::new("path", DataType::Utf8, false), Field::new("metadata_size_bytes", DataType::UInt64, false), // expires field in ListFilesEntry has type Instant when set, from which we cannot get "the number of seconds", hence using Duration instead of Timestamp as data type. @@ -821,7 +824,7 @@ impl TableFunctionImpl for ListFilesCacheFunc { let mut current_offset: i32 = 0; for (path, entry) in list_files_cache.list_entries() { - table_arr.push(path.table.map_or("NULL".to_string(), |t| t.to_string())); + table_arr.push(path.table.map(|t| t.to_string())); path_arr.push(path.path.to_string()); metadata_size_bytes_arr.push(entry.size_bytes as u64); // calculates time left before entry expires diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index df7afc14048b9..c53272ee196ce 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -20,7 +20,7 @@ use std::borrow::Cow; -use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; +use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -33,6 +33,9 @@ use rustyline::hint::Hinter; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Context, Helper, Result}; +/// Default suggestion shown when the input line is empty. +const DEFAULT_HINT_SUGGESTION: &str = " \\? 
for help, \\q to quit"; + pub struct CliHelper { completer: FilenameCompleter, dialect: Dialect, @@ -114,6 +117,15 @@ impl Highlighter for CliHelper { impl Hinter for CliHelper { type Hint = String; + + fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option { + if line.trim().is_empty() { + let suggestion = Color::gray(DEFAULT_HINT_SUGGESTION); + Some(suggestion) + } else { + None + } + } } /// returns true if the current position is after the open quote for diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index 912a13916a5bd..0a2a2e6c14f03 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -80,16 +80,20 @@ impl Highlighter for SyntaxHighlighter { } /// Convenient utility to return strings with [ANSI color](https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124). -struct Color {} +pub(crate) struct Color {} impl Color { - fn green(s: impl Display) -> String { + pub(crate) fn green(s: impl Display) -> String { format!("\x1b[92m{s}\x1b[0m") } - fn red(s: impl Display) -> String { + pub(crate) fn red(s: impl Display) -> String { format!("\x1b[91m{s}\x1b[0m") } + + pub(crate) fn gray(s: impl Display) -> String { + format!("\x1b[90m{s}\x1b[0m") + } } #[cfg(test)] diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap index 89b646a531f8b..fe454595eb4bc 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -14,7 +14,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. 
Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Failed to allocate diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap index 62f864b3adb6e..bb30e387166bc 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -14,7 +14,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap index 9845d095c9180..891d72e3cc639 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -12,7 +12,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. 
caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 0c632d92f6e37..e56f5ad6b8ca7 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -40,8 +40,9 @@ arrow = { workspace = true } arrow-schema = { workspace = true } datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } datafusion-common = { workspace = true } +nom = "8.0.0" tempfile = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } +tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } [dev-dependencies] arrow-flight = { workspace = true } @@ -62,6 +63,7 @@ mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +serde = { version = "1", features = ["derive"] } serde_json = { workspace = true } strum = { workspace = true } strum_macros = { workspace = true } @@ -70,7 +72,7 @@ tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.19" +uuid = { workspace = true } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 61afbf6682bea..2cf0ec52409f8 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -71,15 +71,16 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| --------------------- | ----------------------------------------------------------------------------------------------------- | 
--------------------------------------------- | -| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | -| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | -| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | -| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | -| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | -| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | -| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | +| adapter_serialization | [`custom_data_source/adapter_serialization.rs`](examples/custom_data_source/adapter_serialization.rs) | Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception | +| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | +| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a 
streaming SQL query against CSV data | +| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | +| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | +| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | +| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | +| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | ## Data IO Examples @@ -143,8 +144,8 @@ cargo run --example dataframe -- dataframe | Subcommand | File Path | Description | | ---------- | ------------------------------------------------------- | ------------------------------------------------------ | -| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | | client | [`flight/client.rs`](examples/flight/client.rs) | Execute SQL queries via Arrow Flight protocol | +| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | | sql_server | [`flight/sql_server.rs`](examples/flight/sql_server.rs) | Standalone SQL server for JDBC clients | ## Proto Examples @@ -153,9 +154,10 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| ------------------------ | --------------------------------------------------------------------------------- | --------------------------------------------------------------- | -| composed_extension_codec | 
[`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| Subcommand | File Path | Description | +| ------------------------ | --------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | +| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| expression_deduplication | [`proto/expression_deduplication.rs`](examples/proto/expression_deduplication.rs) | Example of expression caching/deduplication using the codec decorator pattern | ## Query Planning Examples diff --git a/datafusion-examples/examples/builtin_functions/main.rs b/datafusion-examples/examples/builtin_functions/main.rs index 638f56dfbe463..42ca15f91935d 100644 --- a/datafusion-examples/examples/builtin_functions/main.rs +++ b/datafusion-examples/examples/builtin_functions/main.rs @@ -26,9 +26,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `date_time` — examples of date-time related functions and queries -//! - `function_factory` — register `CREATE FUNCTION` handler to implement SQL macros -//! - `regexp` — examples of using regular expression functions +//! +//! - `date_time` +//! (file: date_time.rs, desc: Examples of date-time related functions and queries) +//! +//! - `function_factory` +//! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) +//! +//! - `regexp` +//! 
(file: regexp.rs, desc: Examples of using regular expression functions) mod date_time; mod function_factory; diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs new file mode 100644 index 0000000000000..f19d628fa8bee --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -0,0 +1,519 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods (`serialize_physical_plan` and `deserialize_physical_plan`) +//! to implement custom serialization logic. +//! +//! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by +//! default. This example shows how to: +//! 1. Detect plans with custom adapters during serialization +//! 2. Wrap them as Extension nodes with JSON-serialized adapter metadata +//! 3. Store the inner DataSourceExec (without adapter) as a child in the extension's inputs field +//! 4. Unwrap and restore the adapter during deserialization +//! +//! 
This demonstrates nested serialization (protobuf outer, JSON inner) and the power +//! of the `PhysicalExtensionCodec` interception pattern. Both plan and expression +//! serialization route through the codec, enabling interception at every node in the tree. + +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::record_batch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::assert_batches_eq; +use datafusion::common::{Result, not_impl_err}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::datasource::physical_plan::{FileScanConfig, FileScanConfigBuilder}; +use datafusion::datasource::source::DataSourceExec; +use datafusion::execution::TaskContext; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + PhysicalExtensionCodec, PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; +use datafusion_proto::protobuf::{ + PhysicalExprNode, PhysicalExtensionNode, PhysicalPlanNode, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; +use serde::{Deserialize, Serialize}; + +/// Example showing how to preserve custom adapter 
information during plan serialization. +/// +/// This demonstrates: +/// 1. Creating a custom PhysicalExprAdapter with metadata +/// 2. Using PhysicalExtensionCodec to intercept serialization +/// 3. Wrapping adapter info as Extension nodes +/// 4. Restoring adapters during deserialization +pub async fn adapter_serialization() -> Result<()> { + println!("=== PhysicalExprAdapter Serialization Example ===\n"); + + // Step 1: Create sample Parquet data in memory + println!("Step 1: Creating sample Parquet data..."); + let store = Arc::new(InMemory::new()) as Arc; + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?; + let path = Path::from("data.parquet"); + write_parquet(&store, &path, &batch).await?; + + // Step 2: Set up session with custom adapter + println!("Step 2: Setting up session with custom adapter..."); + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::clone(&store), + ); + + // Create a table with our custom MetadataAdapterFactory + let adapter_factory = Arc::new(MetadataAdapterFactory::new("v1")); + let listing_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///data.parquet")?) + .infer_options(&ctx.state()) + .await? 
+ .with_schema(logical_schema) + .with_expr_adapter_factory( + Arc::clone(&adapter_factory) as Arc + ); + let table = ListingTable::try_new(listing_config)?; + ctx.register_table("my_table", Arc::new(table))?; + + // Step 3: Create physical plan with filter + println!("Step 3: Creating physical plan with filter..."); + let df = ctx.sql("SELECT * FROM my_table WHERE id > 5").await?; + let original_plan = df.create_physical_plan().await?; + + // Verify adapter is present in original plan + let has_adapter_before = verify_adapter_in_plan(&original_plan, "original"); + println!(" Original plan has adapter: {has_adapter_before}"); + + // Step 4: Serialize with our custom codec + println!("\nStep 4: Serializing plan with AdapterPreservingCodec..."); + let codec = AdapterPreservingCodec; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&original_plan), + &codec, + &codec, + )?; + println!(" Serialized {} bytes", bytes.len()); + println!(" (DataSourceExec with adapter was wrapped as PhysicalExtensionNode)"); + + // Step 5: Deserialize with our custom codec + println!("\nStep 5: Deserializing plan with AdapterPreservingCodec..."); + let task_ctx = ctx.task_ctx(); + let restored_plan = + physical_plan_from_bytes_with_proto_converter(&bytes, &task_ctx, &codec, &codec)?; + + // Verify adapter is restored + let has_adapter_after = verify_adapter_in_plan(&restored_plan, "restored"); + println!(" Restored plan has adapter: {has_adapter_after}"); + + // Step 6: Execute and compare results + println!("\nStep 6: Executing plans and comparing results..."); + let original_results = + datafusion::physical_plan::collect(Arc::clone(&original_plan), task_ctx.clone()) + .await?; + let restored_results = + datafusion::physical_plan::collect(restored_plan, task_ctx).await?; + + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 6 |", + "| 7 |", + "| 8 |", + "| 9 |", + "| 10 |", + "+----+", + ]; + + println!("\n Original plan results:"); + 
arrow::util::pretty::print_batches(&original_results)?; + assert_batches_eq!(expected, &original_results); + + println!("\n Restored plan results:"); + arrow::util::pretty::print_batches(&restored_results)?; + assert_batches_eq!(expected, &restored_results); + + println!("\n=== Example Complete! ==="); + println!("Key takeaways:"); + println!( + " 1. PhysicalExtensionCodec provides serialize_physical_plan/deserialize_physical_plan hooks" + ); + println!(" 2. Custom metadata can be wrapped as PhysicalExtensionNode"); + println!(" 3. Nested serialization (protobuf + JSON) works seamlessly"); + println!( + " 4. Both plans produce identical results despite serialization round-trip" + ); + println!(" 5. Adapters are fully preserved through the serialization round-trip"); + + Ok(()) +} + +// ============================================================================ +// MetadataAdapter - A simple custom adapter with a tag +// ============================================================================ + +/// A custom PhysicalExprAdapter that wraps another adapter. +/// The tag metadata is stored in the factory, not the adapter itself. +#[derive(Debug)] +struct MetadataAdapter { + inner: Arc, +} + +impl PhysicalExprAdapter for MetadataAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // Simply delegate to inner adapter + self.inner.rewrite(expr) + } +} + +// ============================================================================ +// MetadataAdapterFactory - Factory for creating MetadataAdapter instances +// ============================================================================ + +/// Factory for creating MetadataAdapter instances. +/// The tag is stored in the factory and extracted via Debug formatting in `extract_adapter_tag`. +#[derive(Debug)] +struct MetadataAdapterFactory { + // Note: This field is read via Debug formatting in `extract_adapter_tag`. + // Rust's dead code analysis doesn't recognize Debug-based field access. 
+ // In PR #19234, this field is used by `with_partition_values`, but that method + // doesn't exist in upstream DataFusion's PhysicalExprAdapter trait. + #[expect(dead_code)] + tag: String, +} + +impl MetadataAdapterFactory { + fn new(tag: impl Into) -> Self { + Self { tag: tag.into() } + } +} + +impl PhysicalExprAdapterFactory for MetadataAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let inner = DefaultPhysicalExprAdapterFactory + .create(logical_file_schema, physical_file_schema)?; + Ok(Arc::new(MetadataAdapter { inner })) + } +} + +// ============================================================================ +// AdapterPreservingCodec - Custom codec that preserves adapters +// ============================================================================ + +/// Extension payload structure for serializing adapter info +#[derive(Serialize, Deserialize)] +struct ExtensionPayload { + /// Marker to identify this is our custom extension + marker: String, + /// JSON-serialized adapter metadata + adapter_metadata: AdapterMetadata, +} + +/// Metadata about the adapter to recreate it during deserialization +#[derive(Serialize, Deserialize)] +struct AdapterMetadata { + /// The adapter tag (e.g., "v1") + tag: String, +} + +const EXTENSION_MARKER: &str = "adapter_preserving_extension_v1"; + +/// A codec that intercepts serialization to preserve adapter information. 
+#[derive(Debug)] +struct AdapterPreservingCodec; + +impl PhysicalExtensionCodec for AdapterPreservingCodec { + // Required method: decode custom extension nodes + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + // Try to parse as our extension payload + if let Ok(payload) = serde_json::from_slice::(buf) + && payload.marker == EXTENSION_MARKER + { + if inputs.len() != 1 { + return Err(datafusion::error::DataFusionError::Plan(format!( + "Extension node expected exactly 1 child, got {}", + inputs.len() + ))); + } + let inner_plan = inputs[0].clone(); + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + not_impl_err!("Unknown extension type") + } + + // Required method: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + // We don't need this for the example - we use serialize_physical_plan instead + not_impl_err!( + "try_encode not used - adapter wrapping happens in serialize_physical_plan" + ) + } +} + +impl PhysicalProtoConverterExtension for AdapterPreservingCodec { + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + // Check if this is a DataSourceExec with adapter + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && let Some(adapter_factory) = &config.expr_adapter_factory + && let Some(tag) = extract_adapter_tag(adapter_factory.as_ref()) + { + // Try to extract our MetadataAdapterFactory's tag + println!(" [Serialize] Found DataSourceExec with adapter tag: {tag}"); + + // 1. Create adapter metadata + let adapter_metadata = AdapterMetadata { tag }; + + // 2. 
Serialize the inner plan to protobuf + // Note that this will drop the custom adapter since the default serialization cannot handle it + let inner_proto = PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + )?; + + // 3. Create extension payload to wrap the plan + // so that the custom adapter gets re-attached during deserialization + // The choice of JSON is arbitrary; other formats could be used. + let payload = ExtensionPayload { + marker: EXTENSION_MARKER.to_string(), + adapter_metadata, + }; + let payload_bytes = serde_json::to_vec(&payload).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to serialize payload: {e}" + )) + })?; + + // 4. Return as PhysicalExtensionNode with child plan in inputs + return Ok(PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + PhysicalExtensionNode { + node: payload_bytes, + inputs: vec![inner_proto], + }, + )), + }); + } + + // No adapter found, not a DataSourceExec, etc. 
- use default serialization + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // Interception point: override deserialization to unwrap adapters + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + // Check if this is our custom extension wrapper + if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type + && let Ok(payload) = + serde_json::from_slice::(&extension.node) + && payload.marker == EXTENSION_MARKER + { + println!( + " [Deserialize] Found adapter extension with tag: {}", + payload.adapter_metadata.tag + ); + + // Get the inner plan proto from inputs field + if extension.inputs.is_empty() { + return Err(datafusion::error::DataFusionError::Plan( + "Extension node missing child plan in inputs".to_string(), + )); + } + let inner_proto = &extension.inputs[0]; + + // Deserialize the inner plan + let inner_plan = inner_proto.try_into_physical_plan_with_converter( + ctx, + extension_codec, + self, + )?; + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + // Not our extension - use default deserialization + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +// ============================================================================ +// Helper functions +// 
============================================================================ + +/// Write a RecordBatch to Parquet in the object store +async fn write_parquet( + store: &dyn ObjectStore, + path: &Path, + batch: &arrow::record_batch::RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Extract the tag from a MetadataAdapterFactory. +/// +/// Note: Since `PhysicalExprAdapterFactory` doesn't provide `as_any()` for downcasting, +/// we parse the Debug output. In a production system, you might add a dedicated trait +/// method for metadata extraction. +fn extract_adapter_tag(factory: &dyn PhysicalExprAdapterFactory) -> Option { + let debug_str = format!("{factory:?}"); + if debug_str.contains("MetadataAdapterFactory") { + // Extract tag from debug output: MetadataAdapterFactory { tag: "v1" } + if let Some(start) = debug_str.find("tag: \"") { + let after_tag = &debug_str[start + 6..]; + if let Some(end) = after_tag.find('"') { + return Some(after_tag[..end].to_string()); + } + } + } + None +} + +/// Create an adapter factory from a tag +fn create_adapter_factory(tag: &str) -> Arc { + Arc::new(MetadataAdapterFactory::new(tag)) +} + +/// Inject an adapter into a plan (assumes plan is a DataSourceExec with FileScanConfig) +fn inject_adapter_into_plan( + plan: Arc, + adapter_factory: Arc, +) -> Result> { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = exec.data_source().as_any().downcast_ref::() + { + let new_config = FileScanConfigBuilder::from(config.clone()) + .with_expr_adapter(Some(adapter_factory)) + .build(); + return Ok(DataSourceExec::from_data_source(new_config)); + } + // If not a DataSourceExec with FileScanConfig, return as-is + Ok(plan) +} + +/// Helper to verify if a plan has an adapter (for 
testing/validation) +fn verify_adapter_in_plan(plan: &Arc, label: &str) -> bool { + // Walk the plan tree to find DataSourceExec with adapter + fn check_plan(plan: &dyn ExecutionPlan) -> bool { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && config.expr_adapter_factory.is_some() + { + return true; + } + // Check children + for child in plan.children() { + if check_plan(child.as_ref()) { + return true; + } + } + false + } + + let has_adapter = check_plan(plan.as_ref()); + println!(" [Verify] {label} plan adapter check: {has_adapter}"); + has_adapter +} diff --git a/datafusion-examples/examples/custom_data_source/csv_json_opener.rs b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs index 347f1a0464716..fc1130313e00c 100644 --- a/datafusion-examples/examples/custom_data_source/csv_json_opener.rs +++ b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs @@ -125,6 +125,7 @@ async fn json_opener() -> Result<()> { projected, FileCompressionType::UNCOMPRESSED, Arc::new(object_store), + true, ); let scan_config = FileScanConfigBuilder::new( diff --git a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs index 895b6f52b6e1e..36cc936332065 100644 --- a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs @@ -156,14 +156,14 @@ impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let inner = self .inner - .create(logical_file_schema, Arc::clone(&physical_file_schema)); - Arc::new(CustomCastsPhysicalExprAdapter { + .create(logical_file_schema, Arc::clone(&physical_file_schema))?; + Ok(Arc::new(CustomCastsPhysicalExprAdapter { physical_file_schema, inner, - }) + })) } } 
diff --git a/datafusion-examples/examples/custom_data_source/default_column_values.rs b/datafusion-examples/examples/custom_data_source/default_column_values.rs index 81d74cfbecabd..d7171542d5186 100644 --- a/datafusion-examples/examples/custom_data_source/default_column_values.rs +++ b/datafusion-examples/examples/custom_data_source/default_column_values.rs @@ -278,18 +278,18 @@ impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let default_factory = DefaultPhysicalExprAdapterFactory; let default_adapter = default_factory.create( Arc::clone(&logical_file_schema), Arc::clone(&physical_file_schema), - ); + )?; - Arc::new(DefaultValuePhysicalExprAdapter { + Ok(Arc::new(DefaultValuePhysicalExprAdapter { logical_file_schema, physical_file_schema, default_adapter, - }) + })) } } diff --git a/datafusion-examples/examples/custom_data_source/file_stream_provider.rs b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs index 936da0a33d47b..5b43072d43f80 100644 --- a/datafusion-examples/examples/custom_data_source/file_stream_provider.rs +++ b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs @@ -22,7 +22,7 @@ /// /// On non-Windows systems, this example creates a named pipe (FIFO) and /// writes rows into it asynchronously while DataFusion reads the data -/// through a `FileStreamProvider`. +/// through a `FileStreamProvider`. /// /// This illustrates how to integrate dynamically updated data sources /// with DataFusion without needing to reload the entire dataset each time. 
@@ -126,7 +126,6 @@ mod non_windows { let broken_pipe_timeout = Duration::from_secs(10); let sa = file_path; // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests tasks.spawn_blocking(move || { let file = OpenOptions::new().write(true).open(sa).unwrap(); // Reference time to use when deciding to fail the test diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs index 5846626d81380..0d21a62591129 100644 --- a/datafusion-examples/examples/custom_data_source/main.rs +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -26,14 +26,32 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `csv_json_opener` — use low level FileOpener APIs to read CSV/JSON into Arrow RecordBatches -//! - `csv_sql_streaming` — build and run a streaming query plan from a SQL statement against a local CSV file -//! - `custom_datasource` — run queries against a custom datasource (TableProvider) -//! - `custom_file_casts` — implement custom casting rules to adapt file schemas -//! - `custom_file_format` — write data to a custom file format -//! - `default_column_values` — implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter -//! - `file_stream_provider` — run a query on FileStreamProvider which implements StreamProvider for reading and writing to arbitrary stream sources/sinks +//! +//! - `adapter_serialization` +//! (file: adapter_serialization.rs, desc: Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception) +//! +//! - `csv_json_opener` +//! (file: csv_json_opener.rs, desc: Use low-level FileOpener APIs for CSV/JSON) +//! +//! - `csv_sql_streaming` +//! (file: csv_sql_streaming.rs, desc: Run a streaming SQL query against CSV data) +//! +//! - `custom_datasource` +//! 
(file: custom_datasource.rs, desc: Query a custom TableProvider) +//! +//! - `custom_file_casts` +//! (file: custom_file_casts.rs, desc: Implement custom casting rules) +//! +//! - `custom_file_format` +//! (file: custom_file_format.rs, desc: Write to a custom file format) +//! +//! - `default_column_values` +//! (file: default_column_values.rs, desc: Custom default values using metadata) +//! +//! - `file_stream_provider` +//! (file: file_stream_provider.rs, desc: Read/write via FileStreamProvider for streams) +mod adapter_serialization; mod csv_json_opener; mod csv_sql_streaming; mod custom_datasource; @@ -50,6 +68,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; #[strum(serialize_all = "snake_case")] enum ExampleKind { All, + AdapterSerialization, CsvJsonOpener, CsvSqlStreaming, CustomDatasource, @@ -74,6 +93,9 @@ impl ExampleKind { Box::pin(example.run()).await?; } } + ExampleKind::AdapterSerialization => { + adapter_serialization::adapter_serialization().await? + } ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?, ExampleKind::CsvSqlStreaming => { csv_sql_streaming::csv_sql_streaming().await? diff --git a/datafusion-examples/examples/data_io/catalog.rs b/datafusion-examples/examples/data_io/catalog.rs index d2ddff82e32db..9781a93374ea6 100644 --- a/datafusion-examples/examples/data_io/catalog.rs +++ b/datafusion-examples/examples/data_io/catalog.rs @@ -140,7 +140,6 @@ struct DirSchemaOpts<'a> { /// Schema where every file with extension `ext` in a given `dir` is a table. #[derive(Debug)] struct DirSchema { - ext: String, tables: RwLock>>, } @@ -173,14 +172,8 @@ impl DirSchema { } Ok(Arc::new(Self { tables: RwLock::new(tables), - ext: ext.to_string(), })) } - - #[allow(unused)] - fn name(&self) -> &str { - &self.ext - } } #[async_trait] @@ -217,7 +210,6 @@ impl SchemaProvider for DirSchema { /// If supported by the implementation, removes an existing table from this schema and returns it. 
/// If no table of that name exists, returns Ok(None). - #[allow(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { let mut tables = self.tables.write().unwrap(); log::info!("dropping table {name}"); diff --git a/datafusion-examples/examples/data_io/json_shredding.rs b/datafusion-examples/examples/data_io/json_shredding.rs index d2ffacc9464c2..77dba5a98ac6f 100644 --- a/datafusion-examples/examples/data_io/json_shredding.rs +++ b/datafusion-examples/examples/data_io/json_shredding.rs @@ -275,17 +275,17 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { + ) -> Result> { let default_factory = DefaultPhysicalExprAdapterFactory; let default_adapter = default_factory.create( Arc::clone(&logical_file_schema), Arc::clone(&physical_file_schema), - ); + )?; - Arc::new(ShreddedJsonRewriter { + Ok(Arc::new(ShreddedJsonRewriter { physical_file_schema, default_adapter, - }) + })) } } diff --git a/datafusion-examples/examples/data_io/main.rs b/datafusion-examples/examples/data_io/main.rs index 0b2bd03f7ea9e..0039585d15b60 100644 --- a/datafusion-examples/examples/data_io/main.rs +++ b/datafusion-examples/examples/data_io/main.rs @@ -26,16 +26,36 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `catalog` — register the table into a custom catalog -//! - `json_shredding` — shows how to implement custom filter rewriting for JSON shredding -//! - `parquet_adv_idx` — create a detailed secondary index that covers the contents of several parquet files -//! - `parquet_emb_idx` — store a custom index inside a Parquet file and use it to speed up queries -//! - `parquet_enc_with_kms` — read and write encrypted Parquet files using an encryption factory -//! - `parquet_enc` — read and write encrypted Parquet files using DataFusion -//! 
- `parquet_exec_visitor` — extract statistics by visiting an ExecutionPlan after execution -//! - `parquet_idx` — create an secondary index over several parquet files and use it to speed up queries -//! - `query_http_csv` — configure `object_store` and run a query against files via HTTP -//! - `remote_catalog` — interfacing with a remote catalog (e.g. over a network) +//! +//! - `catalog` +//! (file: catalog.rs, desc: Register tables into a custom catalog) +//! +//! - `json_shredding` +//! (file: json_shredding.rs, desc: Implement filter rewriting for JSON shredding) +//! +//! - `parquet_adv_idx` +//! (file: parquet_advanced_index.rs, desc: Create a secondary index across multiple parquet files) +//! +//! - `parquet_emb_idx` +//! (file: parquet_embedded_index.rs, desc: Store a custom index inside Parquet files) +//! +//! - `parquet_enc` +//! (file: parquet_encrypted.rs, desc: Read & write encrypted Parquet files) +//! +//! - `parquet_enc_with_kms` +//! (file: parquet_encrypted_with_kms.rs, desc: Encrypted Parquet I/O using a KMS-backed factory) +//! +//! - `parquet_exec_visitor` +//! (file: parquet_exec_visitor.rs, desc: Extract statistics by visiting an ExecutionPlan) +//! +//! - `parquet_idx` +//! (file: parquet_index.rs, desc: Create a secondary index) +//! +//! - `query_http_csv` +//! (file: query_http_csv.rs, desc: Query CSV files via HTTP) +//! +//! - `remote_catalog` +//! 
(file: remote_catalog.rs, desc: Interact with a remote catalog) mod catalog; mod json_shredding; diff --git a/datafusion-examples/examples/data_io/parquet_encrypted.rs b/datafusion-examples/examples/data_io/parquet_encrypted.rs index d3cc6a121f8ea..26361e9b52be0 100644 --- a/datafusion-examples/examples/data_io/parquet_encrypted.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted.rs @@ -55,7 +55,7 @@ pub async fn parquet_encrypted() -> datafusion::common::Result<()> { // Create a temporary file location for the encrypted parquet file let tmp_source = TempDir::new()?; - let tempfile = tmp_source.path().join("cars_encrypted"); + let tempfile = tmp_source.path().join("cars_encrypted.parquet"); // Write encrypted parquet let mut options = TableParquetOptions::default(); diff --git a/datafusion-examples/examples/dataframe/main.rs b/datafusion-examples/examples/dataframe/main.rs index 8c294e2f4e9e7..25b5377d38239 100644 --- a/datafusion-examples/examples/dataframe/main.rs +++ b/datafusion-examples/examples/dataframe/main.rs @@ -26,8 +26,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `dataframe` — run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries -//! - `deserialize_to_struct` — convert query results (Arrow ArrayRefs) into Rust structs +//! +//! - `cache_factory` +//! (file: cache_factory.rs, desc: Custom lazy caching for DataFrames using `CacheFactory`) +// +//! - `dataframe` +//! (file: dataframe.rs, desc: Query DataFrames from various sources and write output) +//! +//! - `deserialize_to_struct` +//! 
(file: deserialize_to_struct.rs, desc: Convert Arrow arrays into Rust structs) mod cache_factory; mod dataframe; diff --git a/datafusion-examples/examples/execution_monitoring/main.rs b/datafusion-examples/examples/execution_monitoring/main.rs index 07de57f6b80e2..8f80c36929ca2 100644 --- a/datafusion-examples/examples/execution_monitoring/main.rs +++ b/datafusion-examples/examples/execution_monitoring/main.rs @@ -26,9 +26,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `mem_pool_exec_plan` — shows how to implement memory-aware ExecutionPlan with memory reservation and spilling -//! - `mem_pool_tracking` — demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages -//! - `tracing` — demonstrates the tracing injection feature for the DataFusion runtime +//! +//! - `mem_pool_exec_plan` +//! (file: memory_pool_execution_plan.rs, desc: Memory-aware ExecutionPlan with spilling) +//! +//! - `mem_pool_tracking` +//! (file: memory_pool_tracking.rs, desc: Demonstrates memory tracking) +//! +//! - `tracing` +//! 
(file: tracing.rs, desc: Demonstrates tracing integration) mod memory_pool_execution_plan; mod memory_pool_tracking; diff --git a/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs index 48475acbb1542..e51ba46a33135 100644 --- a/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs @@ -38,7 +38,7 @@ use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion::prelude::*; use futures::stream::{StreamExt, TryStreamExt}; @@ -296,8 +296,4 @@ impl ExecutionPlan for BufferingExecutionPlan { }), ))) } - - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) - } } diff --git a/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs index 8d6e5dd7e444d..af3031c690fa3 100644 --- a/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs @@ -110,7 +110,8 @@ async fn automatic_usage_example() -> Result<()> { println!("✓ Expected memory limit error during data processing:"); println!("Error: {e}"); /* Example error message: - Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes + Error: Not enough memory to continue external sort. 
Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: ExternalSorterMerge[3]#112(can spill: false) consumed 10.0 MB, peak 10.0 MB, diff --git a/datafusion-examples/examples/external_dependency/main.rs b/datafusion-examples/examples/external_dependency/main.rs index 0a9a2cd2372d9..447e7d38bdd5b 100644 --- a/datafusion-examples/examples/external_dependency/main.rs +++ b/datafusion-examples/examples/external_dependency/main.rs @@ -26,8 +26,12 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `dataframe_to_s3` — run a query using a DataFrame against a parquet file from AWS S3 and writing back to AWS S3 -//! - `query_aws_s3` — configure `object_store` and run a query against files stored in AWS S3 +//! +//! - `dataframe_to_s3` +//! (file: dataframe_to_s3.rs, desc: Query DataFrames and write results to S3) +//! +//! - `query_aws_s3` +//! (file: query_aws_s3.rs, desc: Query S3-backed data using object_store) mod dataframe_to_s3; mod query_aws_s3; diff --git a/datafusion-examples/examples/flight/main.rs b/datafusion-examples/examples/flight/main.rs index 6f20f576d3a7b..426e806486f70 100644 --- a/datafusion-examples/examples/flight/main.rs +++ b/datafusion-examples/examples/flight/main.rs @@ -29,9 +29,15 @@ //! Note: The Flight server must be started in a separate process //! before running the `client` example. Therefore, running `all` will //! not produce a full server+client workflow automatically. -//! - `client` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol -//! - `server` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol -//! 
- `sql_server` — run DataFusion as a standalone process and execute SQL queries from JDBC clients +//! +//! - `client` +//! (file: client.rs, desc: Execute SQL queries via Arrow Flight protocol) +//! +//! - `server` +//! (file: server.rs, desc: Run DataFusion server accepting FlightSQL/JDBC queries) +//! +//! - `sql_server` +//! (file: sql_server.rs, desc: Standalone SQL server for JDBC clients) mod client; mod server; diff --git a/datafusion-examples/examples/flight/sql_server.rs b/datafusion-examples/examples/flight/sql_server.rs index 78b3aaa05a188..e55aaa7250ea7 100644 --- a/datafusion-examples/examples/flight/sql_server.rs +++ b/datafusion-examples/examples/flight/sql_server.rs @@ -120,7 +120,6 @@ impl FlightSqlServiceImpl { Ok(uuid) } - #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -146,7 +145,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -155,7 +153,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -203,13 +200,11 @@ impl FlightSqlServiceImpl { .unwrap() } - #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } - #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs new file mode 100644 index 0000000000000..0dec807f8043a --- /dev/null +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -0,0 +1,275 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods to implement expression deduplication during deserialization. +//! +//! This pattern is inspired by PR #18192, which introduces expression caching +//! to reduce memory usage when deserializing plans with duplicate expressions. +//! +//! The key insight is that identical expressions serialize to identical protobuf bytes. +//! By caching deserialized expressions keyed by their protobuf bytes, we can: +//! 1. Return the same Arc for duplicate expressions +//! 2. Reduce memory allocation during deserialization +//! 3. Enable downstream optimizations that rely on Arc pointer equality +//! +//! This demonstrates the decorator pattern enabled by the `PhysicalExtensionCodec` trait, +//! where all expression serialization/deserialization routes through the codec methods. 
+ +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::{BinaryExpr, col}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +/// Example showing how to implement expression deduplication using the codec decorator pattern. +/// +/// This demonstrates: +/// 1. Creating a CachingCodec that caches expressions by their protobuf bytes +/// 2. Intercepting deserialization to return cached Arcs for duplicate expressions +/// 3. Verifying that duplicate expressions share the same Arc after deserialization +/// +/// Deduplication is keyed by the protobuf bytes representing the expression, +/// in reality deduplication could be done based on e.g. the pointer address of the +/// serialized expression in memory, but this is simpler to demonstrate. +/// +/// In this case our expression is trivial and just for demonstration purposes. +/// In real scenarios, expressions can be much more complex, e.g. a large InList +/// expression could be megabytes in size, so deduplication can save significant memory +/// in addition to more correctly representing the original plan structure. 
+pub async fn expression_deduplication() -> Result<()> { + println!("=== Expression Deduplication Example ===\n"); + + // Create a schema for our test expressions + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + // Step 1: Create expressions with duplicates + println!("Step 1: Creating expressions with duplicates..."); + + // Create expression: col("a") + let a = col("a", &schema)?; + + // Create a clone to show duplicates + let a_clone = Arc::clone(&a); + + // Combine: a OR a_clone + let combined_expr = + Arc::new(BinaryExpr::new(a, Operator::Or, a_clone)) as Arc; + println!(" Created expression: a OR a with duplicates"); + println!(" Note: a appears twice in the expression tree\n"); + // Step 2: Create a filter plan with this expression + println!("Step 2: Creating physical plan with the expression..."); + + let input = Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))); + let filter_plan: Arc = + Arc::new(FilterExec::try_new(combined_expr, input)?); + + println!(" Created FilterExec with duplicate sub-expressions\n"); + + // Step 3: Serialize with the caching codec + println!("Step 3: Serializing plan..."); + + let extension_codec = DefaultPhysicalExtensionCodec {}; + let caching_converter = CachingCodec::new(); + let proto = + caching_converter.execution_plan_to_proto(&filter_plan, &extension_codec)?; + + // Serialize to bytes + let mut bytes = Vec::new(); + proto.encode(&mut bytes).unwrap(); + println!(" Serialized plan to {} bytes\n", bytes.len()); + + // Step 4: Deserialize with the caching codec + println!("Step 4: Deserializing plan with CachingCodec..."); + + let ctx = SessionContext::new(); + let deserialized_plan = proto.try_into_physical_plan_with_converter( + &ctx.task_ctx(), + &extension_codec, + &caching_converter, + )?; + + // Step 5: check that we deduplicated expressions + println!("Step 5: Checking for deduplicated expressions..."); + let Some(filter_exec) = 
deserialized_plan.as_any().downcast_ref::<FilterExec>() + else { + panic!("Deserialized plan is not a FilterExec"); + }; + let predicate = Arc::clone(filter_exec.predicate()); + let binary_expr = predicate + .as_any() + .downcast_ref::<BinaryExpr>() + .expect("Predicate is not a BinaryExpr"); + let left = &binary_expr.left(); + let right = &binary_expr.right(); + // Check if left and right point to the same Arc + let deduplicated = Arc::ptr_eq(left, right); + if deduplicated { + println!(" Success: Duplicate expressions were deduplicated!"); + println!( + " Cache Stats: hits={}, misses={}", + caching_converter.stats.read().unwrap().cache_hits, + caching_converter.stats.read().unwrap().cache_misses, + ); + } else { + println!(" Failure: Duplicate expressions were NOT deduplicated."); + } + + Ok(()) +} + +// ============================================================================ +// CachingCodec - Implements expression deduplication +// ============================================================================ + +/// Statistics for cache performance monitoring +#[derive(Debug, Default)] +struct CacheStats { + cache_hits: usize, + cache_misses: usize, +} + +/// A codec that caches deserialized expressions to enable deduplication. +/// +/// When deserializing, if we've already seen the same protobuf bytes, +/// we return the cached Arc instead of creating a new allocation.
+#[derive(Debug, Default)] +struct CachingCodec { + /// Cache mapping protobuf bytes -> deserialized expression + expr_cache: RwLock, Arc>>, + /// Statistics for demonstration + stats: RwLock, +} + +impl CachingCodec { + fn new() -> Self { + Self::default() + } +} + +impl PhysicalExtensionCodec for CachingCodec { + // Required: decode custom extension nodes + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + datafusion::common::not_impl_err!("No custom extension nodes") + } + + // Required: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + datafusion::common::not_impl_err!("No custom extension nodes") + } +} + +impl PhysicalProtoConverterExtension for CachingCodec { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // CACHING IMPLEMENTATION: Intercept expression deserialization + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + // Create cache key from protobuf bytes + let mut key = Vec::new(); + proto.encode(&mut key).map_err(|e| { + datafusion::error::DataFusionError::Internal(format!( + "Failed to encode proto for cache key: {e}" + )) + })?; + + // Check cache first + { + let cache = self.expr_cache.read().unwrap(); + if let Some(cached) = cache.get(&key) { + // Cache hit! 
Update stats and return cached Arc + let mut stats = self.stats.write().unwrap(); + stats.cache_hits += 1; + return Ok(Arc::clone(cached)); + } + } + + // Cache miss - deserialize and store + let expr = + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + + // Store in cache + { + let mut cache = self.expr_cache.write().unwrap(); + cache.insert(key, Arc::clone(&expr)); + let mut stats = self.stats.write().unwrap(); + stats.cache_misses += 1; + } + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs index f56078b31997d..3f525b5d46afa 100644 --- a/datafusion-examples/examples/proto/main.rs +++ b/datafusion-examples/examples/proto/main.rs @@ -21,14 +21,20 @@ //! //! ## Usage //! ```bash -//! cargo run --example proto -- [all|composed_extension_codec] +//! cargo run --example proto -- [all|composed_extension_codec|expression_deduplication] //! ``` //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `composed_extension_codec` — example of using multiple extension codecs for serialization / deserialization +//! +//! - `composed_extension_codec` +//! (file: composed_extension_codec.rs, desc: Use multiple extension codecs for serialization/deserialization) +//! +//! - `expression_deduplication` +//! 
(file: expression_deduplication.rs, desc: Example of expression caching/deduplication using the codec decorator pattern) mod composed_extension_codec; +mod expression_deduplication; use datafusion::error::{DataFusionError, Result}; use strum::{IntoEnumIterator, VariantNames}; @@ -39,6 +45,7 @@ use strum_macros::{Display, EnumIter, EnumString, VariantNames}; enum ExampleKind { All, ComposedExtensionCodec, + ExpressionDeduplication, } impl ExampleKind { @@ -59,6 +66,9 @@ impl ExampleKind { ExampleKind::ComposedExtensionCodec => { composed_extension_codec::composed_extension_codec().await? } + ExampleKind::ExpressionDeduplication => { + expression_deduplication::expression_deduplication().await? + } } Ok(()) } diff --git a/datafusion-examples/examples/query_planning/main.rs b/datafusion-examples/examples/query_planning/main.rs index ec21c3ea5a76a..d3f99aedceb3d 100644 --- a/datafusion-examples/examples/query_planning/main.rs +++ b/datafusion-examples/examples/query_planning/main.rs @@ -26,14 +26,30 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `analyzer_rule` — use a custom AnalyzerRule to change a query's semantics (row level access control) -//! - `expr_api` — create, execute, simplify, analyze and coerce `Expr`s -//! - `optimizer_rule` — use a custom OptimizerRule to replace certain predicates -//! - `parse_sql_expr` — parse SQL text into DataFusion `Expr` -//! - `plan_to_sql` — generate SQL from DataFusion `Expr` and `LogicalPlan` -//! - `planner_api` — APIs to manipulate logical and physical plans -//! - `pruning` — APIs to manipulate logical and physical plans -//! - `thread_pools` — demonstrate TrackConsumersPool for memory tracking and debugging with enhanced error messages and shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +//! +//! - `analyzer_rule` +//! (file: analyzer_rule.rs, desc: Custom AnalyzerRule to change query semantics) +//! +//! 
- `expr_api` +//! (file: expr_api.rs, desc: Create, execute, analyze, and coerce Exprs) +//! +//! - `optimizer_rule` +//! (file: optimizer_rule.rs, desc: Replace predicates via a custom OptimizerRule) +//! +//! - `parse_sql_expr` +//! (file: parse_sql_expr.rs, desc: Parse SQL into DataFusion Expr) +//! +//! - `plan_to_sql` +//! (file: plan_to_sql.rs, desc: Generate SQL from expressions or plans) +//! +//! - `planner_api` +//! (file: planner_api.rs, desc: APIs for logical and physical plan manipulation) +//! +//! - `pruning` +//! (file: pruning.rs, desc: Use pruning to skip irrelevant files) +//! +//! - `thread_pools` +//! (file: thread_pools.rs, desc: Configure custom thread pools for DataFusion execution) mod analyzer_rule; mod expr_api; diff --git a/datafusion-examples/examples/relation_planner/main.rs b/datafusion-examples/examples/relation_planner/main.rs index 15079f644612d..babc0d3714f72 100644 --- a/datafusion-examples/examples/relation_planner/main.rs +++ b/datafusion-examples/examples/relation_planner/main.rs @@ -27,9 +27,15 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `match_recognize` — MATCH_RECOGNIZE pattern matching on event streams -//! - `pivot_unpivot` — PIVOT and UNPIVOT operations for reshaping data -//! - `table_sample` — TABLESAMPLE clause for sampling rows from tables +//! +//! - `match_recognize` +//! (file: match_recognize.rs, desc: Implement MATCH_RECOGNIZE pattern matching) +//! +//! - `pivot_unpivot` +//! (file: pivot_unpivot.rs, desc: Implement PIVOT / UNPIVOT) +//! +//! - `table_sample` +//! (file: table_sample.rs, desc: Implement TABLESAMPLE) //! //! ## Snapshot Testing //! 
diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs index 60baf9bd61a62..c4b3d522efc17 100644 --- a/datafusion-examples/examples/relation_planner/match_recognize.rs +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -362,7 +362,7 @@ impl RelationPlanner for MatchRecognizePlanner { .. } = relation else { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); }; // Plan the input table @@ -401,6 +401,8 @@ impl RelationPlanner for MatchRecognizePlanner { node: Arc::new(node), }); - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } } diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs index 86a6cb955500e..2e1696956bf62 100644 --- a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs +++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs @@ -339,7 +339,7 @@ impl RelationPlanner for PivotUnpivotPlanner { alias, ), - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } @@ -459,7 +459,9 @@ fn plan_pivot( .aggregate(group_by_cols, pivot_exprs)? 
.build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // ============================================================================ @@ -540,7 +542,9 @@ fn plan_unpivot( .build()?; } - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // ============================================================================ diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs index 362d35dcf4cac..657432ef31362 100644 --- a/datafusion-examples/examples/relation_planner/table_sample.rs +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -331,7 +331,7 @@ impl RelationPlanner for TableSamplePlanner { index_hints, } = relation else { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); }; // Extract sample spec (handles both before/after alias positions) @@ -401,7 +401,9 @@ impl RelationPlanner for TableSamplePlanner { let fraction = bucket_num as f64 / total as f64; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); - return Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))); + return Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))); } // Handle quantity-based sampling @@ -422,7 +424,9 @@ impl RelationPlanner for TableSamplePlanner { let plan = LogicalPlanBuilder::from(input) .limit(0, Some(rows as usize))? 
.build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // TABLESAMPLE (N PERCENT) - percentage sampling @@ -430,7 +434,9 @@ impl RelationPlanner for TableSamplePlanner { let percent: f64 = parse_literal::(&quantity_value_expr)?; let fraction = percent / 100.0; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 @@ -448,7 +454,9 @@ impl RelationPlanner for TableSamplePlanner { // Interpret as fraction TableSamplePlanNode::new(input, value, seed).into_plan() }; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } } } diff --git a/datafusion-examples/examples/sql_ops/main.rs b/datafusion-examples/examples/sql_ops/main.rs index aaab7778be0e4..ce7be8fa2bada 100644 --- a/datafusion-examples/examples/sql_ops/main.rs +++ b/datafusion-examples/examples/sql_ops/main.rs @@ -26,10 +26,18 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `analysis` — analyse SQL queries with DataFusion structures -//! - `custom_sql_parser` — implementing a custom SQL parser to extend DataFusion -//! - `frontend` — create LogicalPlans (only) from sql strings -//! - `query` — query data using SQL (in memory RecordBatches, local Parquet files) +//! +//! - `analysis` +//! (file: analysis.rs, desc: Analyze SQL queries) +//! +//! - `custom_sql_parser` +//! (file: custom_sql_parser.rs, desc: Implement a custom SQL parser to extend DataFusion) +//! +//! - `frontend` +//! (file: frontend.rs, desc: Build LogicalPlans from SQL) +//! +//! - `query` +//! 
(file: query.rs, desc: Query data using SQL) mod analysis; mod custom_sql_parser; diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs index aff20e7754296..e024e466ab07e 100644 --- a/datafusion-examples/examples/udf/main.rs +++ b/datafusion-examples/examples/udf/main.rs @@ -26,14 +26,30 @@ //! //! Each subcommand runs a corresponding example: //! - `all` — run all examples included in this module -//! - `adv_udaf` — user defined aggregate function example -//! - `adv_udf` — user defined scalar function example -//! - `adv_udwf` — user defined window function example -//! - `async_udf` — asynchronous user defined function example -//! - `udaf` — simple user defined aggregate function example -//! - `udf` — simple user defined scalar function example -//! - `udtf` — simple user defined table function example -//! - `udwf` — simple user defined window function example +//! +//! - `adv_udaf` +//! (file: advanced_udaf.rs, desc: Advanced User Defined Aggregate Function (UDAF)) +//! +//! - `adv_udf` +//! (file: advanced_udf.rs, desc: Advanced User Defined Scalar Function (UDF)) +//! +//! - `adv_udwf` +//! (file: advanced_udwf.rs, desc: Advanced User Defined Window Function (UDWF)) +//! +//! - `async_udf` +//! (file: async_udf.rs, desc: Asynchronous User Defined Scalar Function) +//! +//! - `udaf` +//! (file: simple_udaf.rs, desc: Simple UDAF example) +//! +//! - `udf` +//! (file: simple_udf.rs, desc: Simple UDF example) +//! +//! - `udtf` +//! (file: simple_udtf.rs, desc: Simple UDTF example) +//! +//! - `udwf` +//! 
(file: simple_udwf.rs, desc: Simple UDWF example) mod advanced_udaf; mod advanced_udf; diff --git a/datafusion-examples/src/bin/examples-docs.rs b/datafusion-examples/src/bin/examples-docs.rs new file mode 100644 index 0000000000000..7efcf4da15d20 --- /dev/null +++ b/datafusion-examples/src/bin/examples-docs.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generates Markdown documentation for DataFusion example groups. +//! +//! This binary scans `datafusion-examples/examples`, extracts structured +//! documentation from each group's `main.rs` file, and renders a README-style +//! Markdown document. +//! +//! By default, documentation is generated for all example groups. If a group +//! name is provided as the first CLI argument, only that group is rendered. +//! +//! ## Usage +//! +//! ```bash +//! # Generate docs for all example groups +//! cargo run --bin examples-docs +//! +//! # Generate docs for a single group +//! cargo run --bin examples-docs -- dataframe +//! 
``` + +use datafusion_examples::utils::example_metadata::{ + RepoLayout, generate_examples_readme, +}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let layout = RepoLayout::detect()?; + let group = std::env::args().nth(1); + let markdown = generate_examples_readme(&layout, group.as_deref())?; + print!("{markdown}"); + Ok(()) +} diff --git a/datafusion-examples/src/utils/csv_to_parquet.rs b/datafusion-examples/src/utils/csv_to_parquet.rs index 16541b13ae9a9..1fbf2930e9043 100644 --- a/datafusion-examples/src/utils/csv_to_parquet.rs +++ b/datafusion-examples/src/utils/csv_to_parquet.rs @@ -18,9 +18,8 @@ use std::path::{Path, PathBuf}; use datafusion::dataframe::DataFrameWriteOptions; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::{CsvReadOptions, SessionContext}; -use datafusion_common::DataFusionError; use tempfile::TempDir; use tokio::fs::create_dir_all; diff --git a/datafusion-examples/src/utils/datasets/mod.rs b/datafusion-examples/src/utils/datasets/mod.rs index 47f946f7d89ee..1857e6af9b559 100644 --- a/datafusion-examples/src/utils/datasets/mod.rs +++ b/datafusion-examples/src/utils/datasets/mod.rs @@ -18,8 +18,7 @@ use std::path::PathBuf; use arrow_schema::SchemaRef; -use datafusion::error::Result; -use datafusion_common::DataFusionError; +use datafusion::error::{DataFusionError, Result}; pub mod cars; pub mod regex; @@ -50,10 +49,11 @@ impl ExampleDataset { } pub fn path_str(&self) -> Result<String> { - self.path().to_str().map(String::from).ok_or_else(|| { + let path = self.path(); + path.to_str().map(String::from).ok_or_else(|| { DataFusionError::Execution(format!( "CSV directory path is not valid UTF-8: {}", - self.path().display() + path.display() )) }) } } diff --git a/datafusion-examples/src/utils/example_metadata/discover.rs b/datafusion-examples/src/utils/example_metadata/discover.rs new file mode 100644 index 0000000000000..1ba5f6d29a14e --- /dev/null +++ 
b/datafusion-examples/src/utils/example_metadata/discover.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for discovering example groups in the repository filesystem. +//! +//! An example group is defined as a directory containing a `main.rs` file +//! under the examples root. This module is intentionally filesystem-focused +//! and does not perform any parsing or rendering. +//! Discovery fails if no valid example groups are found. + +use std::fs; +use std::path::{Path, PathBuf}; + +use datafusion::common::exec_err; +use datafusion::error::Result; + +/// Discovers all example group directories under the given root. +/// +/// A directory is considered an example group if it contains a `main.rs` file. +pub fn discover_example_groups(root: &Path) -> Result<Vec<PathBuf>> { + let mut groups = Vec::new(); + for entry in fs::read_dir(root)?
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() && path.join("main.rs").is_file() { + groups.push(path); + } + } + + if groups.is_empty() { + return exec_err!("No example groups found under: {}", root.display()); + } + + groups.sort(); + Ok(groups) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs::{self, File}; + + use tempfile::TempDir; + + #[test] + fn discover_example_groups_finds_dirs_with_main_rs() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + + // valid example group + let group1 = root.join("group1"); + fs::create_dir(&group1)?; + File::create(group1.join("main.rs"))?; + + // not an example group + let group2 = root.join("group2"); + fs::create_dir(&group2)?; + + let groups = discover_example_groups(root)?; + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], group1); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_main_rs_is_a_directory() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + let group = root.join("group"); + fs::create_dir(&group)?; + fs::create_dir(group.join("main.rs"))?; + + let err = discover_example_groups(root).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_none_found() -> Result<()> { + let tmp = TempDir::new()?; + let err = discover_example_groups(tmp.path()).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/layout.rs b/datafusion-examples/src/utils/example_metadata/layout.rs new file mode 100644 index 0000000000000..ee6fad89855f9 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/layout.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Repository layout utilities. +//! +//! This module provides a small helper (`RepoLayout`) that encapsulates +//! knowledge about the DataFusion repository structure, in particular +//! where example groups are located relative to the repository root. + +use std::path::{Path, PathBuf}; + +use datafusion::error::{DataFusionError, Result}; + +/// Describes the layout of a DataFusion repository. +/// +/// This type centralizes knowledge about where example-related +/// directories live relative to the repository root. +#[derive(Debug, Clone)] +pub struct RepoLayout { + root: PathBuf, +} + +impl From<&Path> for RepoLayout { + fn from(path: &Path) -> Self { + Self { + root: path.to_path_buf(), + } + } +} + +impl RepoLayout { + /// Creates a layout from an explicit repository root. + pub fn from_root(root: PathBuf) -> Self { + Self { root } + } + + /// Detects the repository root based on `CARGO_MANIFEST_DIR`. + /// + /// This is intended for use from binaries inside the workspace. 
+ pub fn detect() -> Result { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + let root = manifest_dir.parent().ok_or_else(|| { + DataFusionError::Execution( + "CARGO_MANIFEST_DIR does not have a parent".to_string(), + ) + })?; + + Ok(Self { + root: root.to_path_buf(), + }) + } + + /// Returns the repository root directory. + pub fn root(&self) -> &Path { + &self.root + } + + /// Returns the `datafusion-examples/examples` directory. + pub fn examples_root(&self) -> PathBuf { + self.root.join("datafusion-examples").join("examples") + } + + /// Returns the directory for a single example group. + /// + /// Example: `examples/udf` + pub fn example_group_dir(&self, group: &str) -> PathBuf { + self.examples_root().join(group) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detect_sets_non_empty_root() -> Result<()> { + let layout = RepoLayout::detect()?; + assert!(!layout.root().as_os_str().is_empty()); + Ok(()) + } + + #[test] + fn examples_root_is_under_repo_root() -> Result<()> { + let layout = RepoLayout::detect()?; + let examples_root = layout.examples_root(); + assert!(examples_root.starts_with(layout.root())); + assert!(examples_root.ends_with("datafusion-examples/examples")); + Ok(()) + } + + #[test] + fn example_group_dir_appends_group_name() -> Result<()> { + let layout = RepoLayout::detect()?; + let group_dir = layout.example_group_dir("foo"); + assert!(group_dir.ends_with("datafusion-examples/examples/foo")); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/mod.rs b/datafusion-examples/src/utils/example_metadata/mod.rs new file mode 100644 index 0000000000000..ab4c8e4a8e4c2 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/mod.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Documentation generator for DataFusion examples. +//! +//! # Design goals +//! +//! - Keep README.md in sync with runnable examples +//! - Fail fast on malformed documentation +//! +//! # Overview +//! +//! Each example group corresponds to a directory under +//! `datafusion-examples/examples/` containing a `main.rs` file. +//! Documentation is extracted from structured `//!` comments in that file. +//! +//! For each example group, the generator produces: +//! +//! ```text +//! ## Examples +//! ### Group: `` +//! #### Category: Single Process | Distributed +//! +//! | Subcommand | File Path | Description | +//! ``` +//! +//! # Usage +//! +//! Generate documentation for a single group only: +//! +//! ```bash +//! cargo run --bin examples-docs -- dataframe +//! ``` +//! +//! Generate documentation for all examples: +//! +//! ```bash +//! cargo run --bin examples-docs +//! 
``` + +pub mod discover; +pub mod layout; +pub mod model; +pub mod parser; +pub mod render; + +#[cfg(test)] +pub mod test_utils; + +pub use layout::RepoLayout; +pub use model::{Category, ExampleEntry, ExampleGroup, GroupName}; +pub use parser::parse_main_rs_docs; +pub use render::generate_examples_readme; diff --git a/datafusion-examples/src/utils/example_metadata/model.rs b/datafusion-examples/src/utils/example_metadata/model.rs new file mode 100644 index 0000000000000..11416d141eb74 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/model.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Domain model for DataFusion example documentation. +//! +//! This module defines the core data structures used to represent +//! example groups, individual examples, and their categorization +//! as parsed from `main.rs` documentation comments. + +use std::path::Path; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::parse_main_rs_docs; + +/// Well-known abbreviations used to preserve correct capitalization +/// when generating human-readable documentation titles. 
+const ABBREVIATIONS: &[(&str, &str)] = &[ + ("dataframe", "DataFrame"), + ("io", "IO"), + ("sql", "SQL"), + ("udf", "UDF"), +]; + +/// A group of related examples (e.g. `builtin_functions`, `udf`). +/// +/// Each group corresponds to a directory containing a `main.rs` file +/// with structured documentation comments. +#[derive(Debug)] +pub struct ExampleGroup { + pub name: GroupName, + pub examples: Vec, + pub category: Category, +} + +impl ExampleGroup { + /// Parses an example group from its directory. + /// + /// The group name is derived from the directory name, and example + /// entries are extracted from `main.rs`. + pub fn from_dir(dir: &Path, category: Category) -> Result { + let raw_name = dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })? + .to_string(); + + let name = GroupName::from_dir_name(raw_name); + let main_rs = dir.join("main.rs"); + let examples = parse_main_rs_docs(&main_rs)?; + + Ok(Self { + name, + examples, + category, + }) + } +} + +/// Represents an example group name in both raw and human-readable forms. +/// +/// For example: +/// - raw: `builtin_functions` +/// - title: `Builtin Functions` +#[derive(Debug)] +pub struct GroupName { + raw: String, + title: String, +} + +impl GroupName { + /// Creates a group name from a directory name. + pub fn from_dir_name(raw: String) -> Self { + let title = raw + .split('_') + .map(format_part) + .collect::>() + .join(" "); + + Self { raw, title } + } + + /// Returns the raw group name (directory name). + pub fn raw(&self) -> &str { + &self.raw + } + + /// Returns a title-cased name for documentation. + pub fn title(&self) -> &str { + &self.title + } +} + +/// A single runnable example within a group. +/// +/// Each entry corresponds to a subcommand documented in `main.rs`. +#[derive(Debug)] +pub struct ExampleEntry { + /// CLI subcommand name. + pub subcommand: String, + /// Rust source file name. 
+ pub file: String, + /// Human-readable description. + pub desc: String, +} + +/// Execution category of an example group. +#[derive(Debug, Default)] +pub enum Category { + /// Runs in a single process. + #[default] + SingleProcess, + /// Requires a distributed setup. + Distributed, +} + +impl Category { + /// Returns the display name used in documentation. + pub fn name(&self) -> &str { + match self { + Self::SingleProcess => "Single Process", + Self::Distributed => "Distributed", + } + } + + /// Determines the category for a group by name. + pub fn for_group(name: &str) -> Self { + match name { + "flight" => Category::Distributed, + _ => Category::SingleProcess, + } + } +} + +/// Formats a single group-name segment for display. +/// +/// This function applies DataFusion-specific capitalization rules: +/// - Known abbreviations (e.g. `sql`, `io`, `udf`) are rendered in all caps +/// - All other segments fall back to standard Title Case +fn format_part(part: &str) -> String { + let lower = part.to_ascii_lowercase(); + + if let Some((_, replacement)) = ABBREVIATIONS.iter().find(|(k, _)| *k == lower) { + return replacement.to_string(); + } + + let mut chars = part.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + None => String::new(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::{ + assert_exec_err_contains, example_group_from_docs, + }; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn category_for_group_works() { + assert!(matches!( + Category::for_group("flight"), + Category::Distributed + )); + assert!(matches!( + Category::for_group("anything_else"), + Category::SingleProcess + )); + } + + #[test] + fn all_subcommand_is_ignored() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `all` — run all examples included in this module + //! + //! - `foo` + //! 
(file: foo.rs, desc: foo example) + "#, + )?; + assert_eq!(group.examples.len(), 1); + assert_eq!(group.examples[0].subcommand, "foo"); + Ok(()) + } + + #[test] + fn metadata_without_subcommand_fails() { + let err = example_group_from_docs("//! (file: foo.rs, desc: missing subcommand)") + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn group_name_handles_abbreviations() { + assert_eq!( + GroupName::from_dir_name("dataframe".to_string()).title(), + "DataFrame" + ); + assert_eq!( + GroupName::from_dir_name("data_io".to_string()).title(), + "Data IO" + ); + assert_eq!( + GroupName::from_dir_name("sql_ops".to_string()).title(), + "SQL Ops" + ); + assert_eq!(GroupName::from_dir_name("udf".to_string()).title(), "UDF"); + } + + #[test] + fn group_name_title_cases() { + let cases = [ + ("very_long_group_name", "Very Long Group Name"), + ("foo", "Foo"), + ("dataframe", "DataFrame"), + ("data_io", "Data IO"), + ("sql_ops", "SQL Ops"), + ("udf", "UDF"), + ]; + for (input, expected) in cases { + let name = GroupName::from_dir_name(input.to_string()); + assert_eq!(name.title(), expected); + } + } + + #[test] + fn parse_group_example_works() -> Result<()> { + let tmp = TempDir::new().unwrap(); + + // Simulate: examples/builtin_functions/ + let group_dir = tmp.path().join("builtin_functions"); + fs::create_dir(&group_dir)?; + + // Write a fake main.rs with docs + let main_rs = group_dir.join("main.rs"); + fs::write( + &main_rs, + r#" + // Licensed to the Apache Software Foundation (ASF) under one + // or more contributor license agreements. See the NOTICE file + // distributed with this work for additional information + // regarding copyright ownership. The ASF licenses this file + // to you under the Apache License, Version 2.0 (the + // "License"); you may not use this file except in compliance + // with the License. 
You may obtain a copy of the License at + // + // http://www.apache.org/licenses/LICENSE-2.0 + // + // Unless required by applicable law or agreed to in writing, + // software distributed under the License is distributed on an + // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + // KIND, either express or implied. See the License for the + // specific language governing permissions and limitations + // under the License. + // + //! # These are miscellaneous function-related examples + //! + //! These examples demonstrate miscellaneous function-related features. + //! + //! ## Usage + //! ```bash + //! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] + //! ``` + //! + //! Each subcommand runs a corresponding example: + //! - `all` — run all examples included in this module + //! + //! - `date_time` + //! (file: date_time.rs, desc: Examples of date-time related functions and queries) + //! + //! - `function_factory` + //! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) + //! + //! - `regexp` + //! 
(file: regexp.rs, desc: Examples of using regular expression functions) + "#, + )?; + + let group = ExampleGroup::from_dir(&group_dir, Category::SingleProcess)?; + + // Assert group-level data + assert_eq!(group.name.title(), "Builtin Functions"); + assert_eq!(group.examples.len(), 3); + + // Assert 1 example + assert_eq!(group.examples[0].subcommand, "date_time"); + assert_eq!(group.examples[0].file, "date_time.rs"); + assert_eq!( + group.examples[0].desc, + "Examples of date-time related functions and queries" + ); + + // Assert 2 example + assert_eq!(group.examples[1].subcommand, "function_factory"); + assert_eq!(group.examples[1].file, "function_factory.rs"); + assert_eq!( + group.examples[1].desc, + "Register `CREATE FUNCTION` handler to implement SQL macros" + ); + + // Assert 3 example + assert_eq!(group.examples[2].subcommand, "regexp"); + assert_eq!(group.examples[2].file, "regexp.rs"); + assert_eq!( + group.examples[2].desc, + "Examples of using regular expression functions" + ); + + Ok(()) + } + + #[test] + fn duplicate_metadata_without_repeating_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn duplicate_metadata_for_same_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! + //! - `foo` + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Duplicate metadata for subcommand `foo`"); + } + + #[test] + fn metadata_must_follow_subcommand() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! some unrelated comment + //! 
(file: foo.rs, desc: test) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn preserves_example_order_from_main_rs() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `second` + //! (file: second.rs, desc: second example) + //! + //! - `first` + //! (file: first.rs, desc: first example) + //! + //! - `third` + //! (file: third.rs, desc: third example) + "#, + )?; + + let subcommands: Vec<&str> = group + .examples + .iter() + .map(|e| e.subcommand.as_str()) + .collect(); + + assert_eq!( + subcommands, + vec!["second", "first", "third"], + "examples must preserve the order defined in main.rs" + ); + + Ok(()) + } + + #[test] + fn metadata_can_follow_blank_doc_line() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `foo` + //! + //! (file: foo.rs, desc: test) + "#, + )?; + assert_eq!(group.examples.len(), 1); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/parser.rs b/datafusion-examples/src/utils/example_metadata/parser.rs new file mode 100644 index 0000000000000..4ead3e5a2ae9f --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/parser.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parser for example metadata embedded in `main.rs` documentation comments. +//! +//! This module scans `//!` doc comments to extract example subcommands +//! and their associated metadata (file name and description), enforcing +//! a strict ordering and structure to avoid ambiguous documentation. + +use std::{collections::HashSet, fs, path::Path}; + +use datafusion::common::exec_err; +use datafusion::error::Result; +use nom::{ + Err, IResult, Parser, + bytes::complete::{tag, take_until, take_while}, + character::complete::multispace0, + combinator::all_consuming, + error::{Error, ErrorKind}, + sequence::{delimited, preceded}, +}; + +use crate::utils::example_metadata::ExampleEntry; + +/// Parsing state machine used while scanning `main.rs` docs. +/// +/// This makes the "subcommand - metadata" relationship explicit: +/// metadata is only valid immediately after a subcommand has been seen. +enum ParserState<'a> { + /// Not currently expecting metadata. + Idle, + /// A subcommand was just parsed; the next valid metadata (if any) + /// must belong to this subcommand. + SeenSubcommand(&'a str), +} + +/// Parses a subcommand declaration line from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! - `` +/// ``` +fn parse_subcommand_line(input: &str) -> IResult<&str, &str> { + let parser = preceded( + multispace0, + delimited(tag("//! - `"), take_until("`"), tag("`")), + ); + all_consuming(parser).parse(input) +} + +/// Parses example metadata (file name and description) from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! 
(file: .rs, desc: ) +/// ``` +fn parse_metadata_line(input: &str) -> IResult<&str, (&str, &str)> { + let parser = preceded( + multispace0, + preceded(tag("//!"), preceded(multispace0, take_while(|_| true))), + ); + let (rest, payload) = all_consuming(parser).parse(input)?; + + let content = payload + .strip_prefix("(") + .and_then(|s| s.strip_suffix(")")) + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + let (file, desc) = content + .strip_prefix("file:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))? + .split_once(", desc:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + Ok((rest, (file.trim(), desc.trim()))) +} + +/// Parses example entries from a group's `main.rs` file. +pub fn parse_main_rs_docs(path: &Path) -> Result> { + let content = fs::read_to_string(path)?; + let mut entries = vec![]; + let mut state = ParserState::Idle; + let mut seen_subcommands = HashSet::new(); + + for (line_no, raw_line) in content.lines().enumerate() { + let line = raw_line.trim(); + + // Try parsing subcommand, excluding `all` because it's not used in README + if let Ok((_, sub)) = parse_subcommand_line(line) { + state = if sub == "all" { + ParserState::Idle + } else { + ParserState::SeenSubcommand(sub) + }; + continue; + } + + // Try parsing metadata + if let Ok((_, (file, desc))) = parse_metadata_line(line) { + let subcommand = match state { + ParserState::SeenSubcommand(s) => s, + ParserState::Idle => { + return exec_err!( + "Metadata without preceding subcommand at {}:{}", + path.display(), + line_no + 1 + ); + } + }; + + if !seen_subcommands.insert(subcommand) { + return exec_err!("Duplicate metadata for subcommand `{subcommand}`"); + } + + entries.push(ExampleEntry { + subcommand: subcommand.to_string(), + file: file.to_string(), + desc: desc.to_string(), + }); + + state = ParserState::Idle; + continue; + } + + // If a non-blank doc line interrupts a pending subcommand, reset the state + if let 
ParserState::SeenSubcommand(_) = state + && is_non_blank_doc_line(line) + { + state = ParserState::Idle; + } + } + + Ok(entries) +} + +/// Returns `true` for non-blank Rust doc comment lines (`//!`). +/// +/// Used to detect when a subcommand is interrupted by unrelated documentation, +/// so metadata is only accepted immediately after a subcommand (blank doc lines +/// are allowed in between). +fn is_non_blank_doc_line(line: &str) -> bool { + line.starts_with("//!") && !line.trim_start_matches("//!").trim().is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::TempDir; + + #[test] + fn parse_subcommand_line_accepts_valid_input() { + let line = "//! - `date_time`"; + let sub = parse_subcommand_line(line); + assert_eq!(sub, Ok(("", "date_time"))); + } + + #[test] + fn parse_subcommand_line_invalid_inputs() { + let err_lines = [ + "//! - ", + "//! - foo", + "//! - `foo` bar", + "//! --", + "//!-", + "//!--", + "//!", + "//", + "/", + "", + ]; + for line in err_lines { + assert!( + parse_subcommand_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_metadata_line_accepts_valid_input() { + let line = + "//! (file: date_time.rs, desc: Examples of date-time related functions)"; + let res = parse_metadata_line(line); + assert_eq!( + res, + Ok(( + "", + ("date_time.rs", "Examples of date-time related functions") + )) + ); + + let line = "//! (file: foo.rs, desc: Foo, bar, baz)"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo, bar, baz")))); + + let line = "//! (file: foo.rs, desc: Foo(FOO))"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo(FOO)")))); + } + + #[test] + fn parse_metadata_line_invalid_inputs() { + let bad_lines = [ + "//! (file: foo.rs)", + "//! (desc: missing file)", + "//! file: foo.rs, desc: test", + "//! file: foo.rs,desc: test", + "//! (file: foo.rs desc: test)", + "//! (file: foo.rs,desc: test)", + "//! 
(desc: test, file: foo.rs)", + "//! ()", + "//! (file: foo.rs, desc: test) extra", + "", + ]; + for line in bad_lines { + assert!( + parse_metadata_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_main_rs_docs_extracts_entries() -> Result<()> { + let tmp = TempDir::new().unwrap(); + let main_rs = tmp.path().join("main.rs"); + + fs::write( + &main_rs, + r#" + //! - `foo` + //! (file: foo.rs, desc: first example) + //! + //! - `bar` + //! (file: bar.rs, desc: second example) + "#, + )?; + + let entries = parse_main_rs_docs(&main_rs)?; + + assert_eq!(entries.len(), 2); + + assert_eq!(entries[0].subcommand, "foo"); + assert_eq!(entries[0].file, "foo.rs"); + assert_eq!(entries[0].desc, "first example"); + + assert_eq!(entries[1].subcommand, "bar"); + assert_eq!(entries[1].file, "bar.rs"); + assert_eq!(entries[1].desc, "second example"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/render.rs b/datafusion-examples/src/utils/example_metadata/render.rs new file mode 100644 index 0000000000000..a4ea620e78352 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/render.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! Markdown renderer for DataFusion example documentation. +//! +//! This module takes parsed example metadata and generates the +//! `README.md` content for `datafusion-examples`, including group +//! sections and example tables. + +use std::path::PathBuf; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::discover::discover_example_groups; +use crate::utils::example_metadata::model::ExampleGroup; +use crate::utils::example_metadata::{Category, RepoLayout}; + +const STATIC_HEADER: &str = r#" + +# DataFusion Examples + +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. + +## Prerequisites + +Run `git submodule update --init` to init test files. + +## Running Examples + +To run an example, use the `cargo run` command, such as: + +```bash +git clone https://github.com/apache/datafusion +cd datafusion +# Download test data +git submodule update --init + +# Change to the examples directory +cd datafusion-examples/examples + +# Run all examples in a group +cargo run --example -- all + +# Run a specific example within a group +cargo run --example -- + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe +``` +"#; + +/// Generates Markdown documentation for DataFusion examples. +/// +/// If `group` is `None`, documentation is generated for all example groups. +/// If `group` is `Some`, only that group is rendered. 
+/// +/// # Errors +/// +/// Returns an error if: +/// - the requested group does not exist +/// - a `main.rs` file is missing +/// - documentation comments are malformed +pub fn generate_examples_readme( + layout: &RepoLayout, + group: Option<&str>, +) -> Result { + let examples_root = layout.examples_root(); + + let mut out = String::new(); + out.push_str(STATIC_HEADER); + + let group_dirs: Vec = match group { + Some(name) => { + let dir = examples_root.join(name); + if !dir.is_dir() { + return Err(DataFusionError::Execution(format!( + "Example group `{name}` does not exist" + ))); + } + vec![dir] + } + None => discover_example_groups(&examples_root)?, + }; + + for group_dir in group_dirs { + let raw_name = + group_dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })?; + + let category = Category::for_group(raw_name); + let group = ExampleGroup::from_dir(&group_dir, category)?; + + out.push_str(&group.render_markdown()); + } + + Ok(out) +} + +impl ExampleGroup { + /// Renders this example group as a Markdown section for the README. 
+ pub fn render_markdown(&self) -> String { + let mut out = String::new(); + out.push_str(&format!("\n## {} Examples\n\n", self.name.title())); + out.push_str(&format!("### Group: `{}`\n\n", self.name.raw())); + out.push_str(&format!("#### Category: {}\n\n", self.category.name())); + out.push_str("| Subcommand | File Path | Description |\n"); + out.push_str("| --- | --- | --- |\n"); + + for example in &self.examples { + out.push_str(&format!( + "| {} | [`{}/{}`](examples/{}/{}) | {} |\n", + example.subcommand, + self.name.raw(), + example.file, + self.name.raw(), + example.file, + example.desc + )); + } + + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn single_group_generation_works() { + let tmp = TempDir::new().unwrap(); + // Fake repo root + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + + // Create: datafusion-examples/examples/builtin_functions + let examples_dir = layout.example_group_dir("builtin_functions"); + fs::create_dir_all(&examples_dir).unwrap(); + + fs::write( + examples_dir.join("main.rs"), + "//! - `x`\n//! 
(file: foo.rs, desc: test)", + ) + .unwrap(); + + let out = generate_examples_readme(&layout, Some("builtin_functions")).unwrap(); + assert!(out.contains("Builtin Functions")); + assert!(out.contains("| x | [`builtin_functions/foo.rs`]")); + } + + #[test] + fn single_group_generation_fails_if_group_missing() { + let tmp = TempDir::new().unwrap(); + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + let err = generate_examples_readme(&layout, Some("missing_group")).unwrap_err(); + assert_exec_err_contains(err, "Example group `missing_group` does not exist"); + } +} diff --git a/datafusion-examples/src/utils/example_metadata/test_utils.rs b/datafusion-examples/src/utils/example_metadata/test_utils.rs new file mode 100644 index 0000000000000..d6ab3b06ba06d --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/test_utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test helpers for example metadata parsing and validation. +//! +//! This module provides small, focused utilities to reduce duplication +//! and keep tests readable across the example metadata submodules. 
+ +use std::fs; + +use datafusion::error::{DataFusionError, Result}; +use tempfile::TempDir; + +use crate::utils::example_metadata::{Category, ExampleGroup}; + +/// Asserts that an `Execution` error contains the expected message fragment. +/// +/// Keeps tests focused on semantic error causes without coupling them +/// to full error string formatting. +pub fn assert_exec_err_contains(err: DataFusionError, needle: &str) { + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains(needle), + "expected '{needle}' in error message, got: {msg}" + ); + } + other => panic!("expected Execution error, got: {other:?}"), + } +} + +/// Helper for grammar-focused tests. +/// +/// Creates a minimal temporary example group with a single `main.rs` +/// containing the provided docs. Intended for testing parsing and +/// validation rules, not full integration behavior. +pub fn example_group_from_docs(docs: &str) -> Result { + let tmp = TempDir::new().map_err(|e| { + DataFusionError::Execution(format!("Failed initializing temp dir: {e}")) + })?; + let dir = tmp.path().join("group"); + fs::create_dir(&dir).map_err(|e| { + DataFusionError::Execution(format!("Failed creating temp dir: {e}")) + })?; + fs::write(dir.join("main.rs"), docs).map_err(|e| { + DataFusionError::Execution(format!("Failed writing to temp file: {e}")) + })?; + ExampleGroup::from_dir(&dir, Category::SingleProcess) +} diff --git a/datafusion-examples/src/utils/mod.rs b/datafusion-examples/src/utils/mod.rs index b9e5b487db3ac..da96724a49cb3 100644 --- a/datafusion-examples/src/utils/mod.rs +++ b/datafusion-examples/src/utils/mod.rs @@ -17,5 +17,6 @@ mod csv_to_parquet; pub mod datasets; +pub mod example_metadata; pub use csv_to_parquet::write_csv_to_parquet; diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index ea016015cebd3..031b2ebfb8109 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ 
-83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 28bd880ea01fb..9efb5aa96267e 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs index 38456944075fc..a5de79b052a4e 100644 --- a/datafusion/catalog-listing/src/table.rs +++ b/datafusion/catalog-listing/src/table.rs @@ -28,7 +28,7 @@ use datafusion_common::{ use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileOutputMode, FileSinkConfig}; #[expect(deprecated)] use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_datasource::{ @@ -674,6 +674,7 @@ impl TableProvider for ListingTable { insert_op, keep_partition_by_columns, file_extension: self.options().format.get_ext(), + file_output_mode: FileOutputMode::Automatic, }; // For writes, we only use user-specified ordering (no file groups to derive from) diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 52bfeca3d4282..ea93dc21a3f5b 100644 --- 
a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -24,7 +24,7 @@ use crate::{CatalogProviderList, SchemaProvider, TableProvider}; use arrow::array::builder::{BooleanBuilder, UInt8Builder}; use arrow::{ array::{StringBuilder, UInt64Builder}, - datatypes::{DataType, Field, Schema, SchemaRef}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, record_batch::RecordBatch, }; use async_trait::async_trait; @@ -34,7 +34,10 @@ use datafusion_common::error::Result; use datafusion_common::types::NativeType; use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::{ + AggregateUDF, ReturnFieldArgs, ScalarUDF, Signature, TypeSignature, WindowUDF, +}; use datafusion_expr::{TableType, Volatility}; use datafusion_physical_plan::SendableRecordBatchStream; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -421,10 +424,24 @@ fn get_udf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let scalar_arguments = vec![None; arg_fields.len()]; let return_type = udf - .return_type(&arg_types) - .map(|t| remove_native_type_prefix(&NativeType::from(t))) + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) .ok(); let arg_types = arg_types .into_iter() @@ -447,11 +464,21 @@ fn get_udaf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented 
[`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); let return_type = udaf - .return_type(&arg_types) - .ok() - .map(|t| remove_native_type_prefix(&NativeType::from(t))); + .return_field(&arg_fields) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() .map(|t| remove_native_type_prefix(&NativeType::from(t))) @@ -473,12 +500,26 @@ fn get_udwf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let return_type = udwf + .field(WindowUDFFieldArgs::new(&arg_fields, udwf.name())) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); - (arg_types, None) + (arg_types, return_type) }) .collect::>()) } diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index d1cd3998fecf1..931941e8fdfad 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! Interfaces and default implementations of catalogs and schemas. //! 
diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs index 31669171b291a..db9596b420b7b 100644 --- a/datafusion/catalog/src/streaming.rs +++ b/datafusion/catalog/src/streaming.rs @@ -20,19 +20,18 @@ use std::any::Any; use std::sync::Arc; -use crate::Session; -use crate::TableProvider; - use arrow::datatypes::SchemaRef; +use async_trait::async_trait; use datafusion_common::{DFSchema, Result, plan_err}; use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::equivalence::project_ordering; use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; - -use async_trait::async_trait; use log::debug; +use crate::{Session, TableProvider}; + /// A [`TableProvider`] that streams a set of [`PartitionStream`] #[derive(Debug)] pub struct StreamingTable { @@ -105,7 +104,22 @@ impl TableProvider for StreamingTable { let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; let eqp = state.execution_props(); - create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)? + let original_sort_exprs = + create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?; + + if let Some(p) = projection { + // When performing a projection, the output columns will not match + // the original physical sort expression indices. Also the sort columns + // may not be in the output projection. To correct for these issues + // we need to project the ordering based on the output schema. 
+ let schema = Arc::new(self.schema.project(p)?); + LexOrdering::new(original_sort_exprs) + .and_then(|lex_ordering| project_ordering(&lex_ordering, &schema)) + .map(|lex_ordering| lex_ordering.to_vec()) + .unwrap_or_default() + } else { + original_sort_exprs + } } else { vec![] }; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 1f223852c2b9d..f31d4d52ce88b 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -353,6 +353,14 @@ pub trait TableProvider: Debug + Sync + Send { ) -> Result> { not_impl_err!("UPDATE not supported for {} table", self.table_type()) } + + /// Remove all rows from the table. + /// + /// Should return an [ExecutionPlan] producing a single row with count (UInt64), + /// representing the number of rows removed. + async fn truncate(&self, _state: &dyn Session) -> Result> { + not_impl_err!("TRUNCATE not supported for {} table", self.table_type()) + } } /// Arguments for scanning a table with [`TableProvider::scan_with_args`]. diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index fdbfe7f2390ca..cf45ccf3ef63a 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -16,7 +16,6 @@ // under the License. 
#![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 710996707a64c..82e7aafcee2b1 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -72,6 +72,7 @@ half = { workspace = true } hashbrown = { workspace = true } hex = { workspace = true, optional = true } indexmap = { workspace = true } +itertools = { workspace = true } libc = "0.2.180" log = { workspace = true } object_store = { workspace = true, optional = true } diff --git a/datafusion/common/benches/with_hashes.rs b/datafusion/common/benches/with_hashes.rs index 8154c20df88f3..9ee31d9c4bef6 100644 --- a/datafusion/common/benches/with_hashes.rs +++ b/datafusion/common/benches/with_hashes.rs @@ -19,11 +19,14 @@ use ahash::RandomState; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, - NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, StringViewArray, make_array, + Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, Int32Array, + Int64Array, ListArray, MapArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, + RunArray, StringViewArray, StructArray, UnionArray, make_array, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, DataType, Field, Fields, Int32Type, Int64Type, UnionFields, }; -use arrow::buffer::NullBuffer; -use arrow::datatypes::{ArrowDictionaryKeyType, Int32Type, Int64Type}; use criterion::{Bencher, Criterion, criterion_group, criterion_main}; use datafusion_common::hash_utils::with_hashes; use 
rand::Rng; @@ -37,6 +40,8 @@ const BATCH_SIZE: usize = 8192; struct BenchData { name: &'static str, array: ArrayRef, + /// Union arrays can't have null bitmasks added + supports_nulls: bool, } fn criterion_benchmark(c: &mut Criterion) { @@ -47,50 +52,93 @@ fn criterion_benchmark(c: &mut Criterion) { BenchData { name: "int64", array: primitive_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8", array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "large_utf8", array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8_view", array: pool.string_view_array(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "utf8_view (small)", array: small_pool.string_view_array(BATCH_SIZE), + supports_nulls: true, }, BenchData { name: "dictionary_utf8_int32", array: pool.dictionary_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "list_array", + array: list_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "map_array", + array: map_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "sparse_union", + array: sparse_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "dense_union", + array: dense_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "struct_array", + array: create_struct_array(&pool, BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "run_array_int32", + array: create_run_array::(BATCH_SIZE), + supports_nulls: true, }, ]; - for BenchData { name, array } in cases { - // with_hash has different code paths for single vs multiple arrays and nulls vs no nulls - let nullable_array = add_nulls(&array); + for BenchData { + name, + array, + supports_nulls, + } in cases + { c.bench_function(&format!("{name}: single, no nulls"), |b| { do_hash_test(b, std::slice::from_ref(&array)); }); - c.bench_function(&format!("{name}: single, nulls"), |b| { - do_hash_test(b, 
std::slice::from_ref(&nullable_array)); - }); c.bench_function(&format!("{name}: multiple, no nulls"), |b| { let arrays = vec![array.clone(), array.clone(), array.clone()]; do_hash_test(b, &arrays); }); - c.bench_function(&format!("{name}: multiple, nulls"), |b| { - let arrays = vec![ - nullable_array.clone(), - nullable_array.clone(), - nullable_array.clone(), - ]; - do_hash_test(b, &arrays); - }); + // Union arrays can't have null bitmasks + if supports_nulls { + let nullable_array = add_nulls(&array); + c.bench_function(&format!("{name}: single, nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&nullable_array)); + }); + c.bench_function(&format!("{name}: multiple, nulls"), |b| { + let arrays = vec![ + nullable_array.clone(), + nullable_array.clone(), + nullable_array.clone(), + ]; + do_hash_test(b, &arrays); + }); + } } } @@ -122,16 +170,51 @@ where builder.finish().expect("should be nulls in buffer") } -// Returns an new array that is the same as array, but with nulls +// Returns a new array that is the same as array, but with nulls +// Handles the special case of RunArray where nulls must be in the values array fn add_nulls(array: &ArrayRef) -> ArrayRef { - let array_data = array - .clone() - .into_data() - .into_builder() - .nulls(Some(create_null_mask(array.len()))) - .build() - .unwrap(); - make_array(array_data) + use arrow::datatypes::DataType; + + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + // RunArray can't have top-level nulls, so apply nulls to the values array + let run_array = array + .as_any() + .downcast_ref::>() + .expect("Expected RunArray"); + + let run_ends_buffer = run_array.run_ends().inner().clone(); + let run_ends_array = PrimitiveArray::::new(run_ends_buffer, None); + let values = run_array.values().clone(); + + // Add nulls to the values array + let values_with_nulls = { + let array_data = values + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(values.len()))) + .build() + .unwrap(); + 
make_array(array_data) + }; + + Arc::new( + RunArray::try_new(&run_ends_array, values_with_nulls.as_ref()) + .expect("Failed to create RunArray with null values"), + ) + } + _ => { + let array_data = array + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(array.len()))) + .build() + .unwrap(); + make_array(array_data) + } + } } pub fn make_rng() -> StdRng { @@ -205,5 +288,282 @@ where Arc::new(array) } -criterion_group!(benches, criterion_benchmark); +/// Benchmark sliced arrays to demonstrate the optimization for when an array is +/// sliced, the underlying buffer may be much larger than what's referenced by +/// the slice. The optimization avoids hashing unreferenced elements. +fn sliced_array_benchmark(c: &mut Criterion) { + // Test with different slice ratios: slice_size / total_size + // Smaller ratio = more potential savings from the optimization + let slice_ratios = [10, 5, 2]; // 1/10, 1/5, 1/2 of total + + for ratio in slice_ratios { + let total_rows = BATCH_SIZE * ratio; + let slice_offset = BATCH_SIZE * (ratio / 2); // Take from middle + let slice_len = BATCH_SIZE; + + // Sliced ListArray + { + let full_array = list_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("list_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced MapArray + { + let full_array = map_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("map_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced Sparse UnionArray + { + let full_array = sparse_union_array(total_rows); + let sliced: ArrayRef = Arc::new( + 
full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("sparse_union_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + } +} + +fn do_hash_test_with_len(b: &mut Bencher, arrays: &[ArrayRef], expected_len: usize) { + let state = RandomState::new(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), expected_len); + Ok(()) + }) + .unwrap(); + }); +} + +fn list_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let elements_per_row = 5; + let total_elements = num_rows * elements_per_row; + + let values: Int64Array = (0..total_elements) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + )) +} + +fn map_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let entries_per_row = 5; + let total_entries = num_rows * entries_per_row; + + let keys: Int32Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let values: Int64Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * entries_per_row) as i32) + .collect(); + + let entries = StructArray::try_new( + Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ]), + vec![Arc::new(keys), Arc::new(values)], + None, + ) + .unwrap(); + + Arc::new(MapArray::new( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ])), + false, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + entries, + None, + 
false, + )) +} + +fn sparse_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(num_rows), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + None, + children, + ) + .unwrap(), + ) +} + +fn dense_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + + let mut type_counts = vec![0i32; num_types]; + for &tid in &type_ids { + type_counts[tid as usize] += 1; + } + + let mut current_offsets = vec![0i32; num_types]; + let offsets: Vec = type_ids + .iter() + .map(|&tid| { + let offset = current_offsets[tid as usize]; + current_offsets[tid as usize] += 1; + offset + }) + .collect(); + + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(type_counts[i] as usize), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + Some(ScalarBuffer::from(offsets)), + children, + ) + .unwrap(), + ) +} + +fn boolean_array(array_len: usize) -> ArrayRef { + let mut rng = make_rng(); + Arc::new( + (0..array_len) + .map(|_| Some(rng.random::())) + .collect::(), + ) +} + +/// Create a StructArray with multiple columns +fn create_struct_array(pool: &StringPool, array_len: usize) -> ArrayRef { + let bool_array = boolean_array(array_len); + let int32_array = primitive_array::(array_len); + let int64_array = primitive_array::(array_len); + let str_array = 
pool.string_array::(array_len); + + let fields = Fields::from(vec![ + Field::new("bool_col", DataType::Boolean, false), + Field::new("int32_col", DataType::Int32, false), + Field::new("int64_col", DataType::Int64, false), + Field::new("string_col", DataType::Utf8, false), + ]); + + Arc::new(StructArray::new( + fields, + vec![bool_array, int32_array, int64_array, str_array], + None, + )) +} + +/// Create a RunArray to test run array hashing. +fn create_run_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + // Create runs of varying lengths + let mut run_ends = Vec::new(); + let mut values = Vec::new(); + let mut current_end = 0; + + while current_end < array_len { + // Random run length between 1 and 50 + let run_length = rng.random_range(1..=50).min(array_len - current_end); + current_end += run_length; + run_ends.push(current_end as i32); + values.push(Some(rng.random::())); + } + + let run_ends_array = Arc::new(PrimitiveArray::::from(run_ends)); + let values_array: Arc = + Arc::new(values.into_iter().collect::>()); + + Arc::new( + RunArray::try_new(&run_ends_array, values_array.as_ref()) + .expect("Failed to create RunArray"), + ) +} + +criterion_group!(benches, criterion_benchmark, sliced_array_benchmark); criterion_main!(benches); diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 29082cc303a70..bc4313ed95665 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -25,8 +25,9 @@ use arrow::array::{ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, - ListViewArray, StringViewArray, UInt16Array, + ListViewArray, RunArray, StringViewArray, UInt16Array, }; +use arrow::datatypes::RunEndIndexType; use arrow::{ array::{ Array, BinaryArray, 
BooleanArray, Date32Array, Date64Array, Decimal128Array, @@ -334,3 +335,8 @@ pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { pub fn as_large_list_view_array(array: &dyn Array) -> Result<&LargeListViewArray> { Ok(downcast_value!(array, LargeListViewArray)) } + +// Downcast Array to RunArray +pub fn as_run_array(array: &dyn Array) -> Result<&RunArray> { + Ok(downcast_value!(array, RunArray, T)) +} diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 87344914d2f7e..dad12c1c6bc91 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -757,7 +757,7 @@ config_namespace! { /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// (writing) Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in rows pub write_batch_size: usize, default = 1024 /// (writing) Sets parquet writer version @@ -1142,6 +1142,12 @@ config_namespace! { /// /// Default: true pub enable_sort_pushdown: bool, default = true + + /// When set to true, the optimizer will extract leaf expressions + /// (such as `get_field`) from filter/sort/join nodes into projections + /// closer to the leaf table scans, and push those projections down + /// towards the leaf nodes. + pub enable_leaf_expression_pushdown: bool, default = true } } @@ -2248,7 +2254,7 @@ impl TableOptions { /// Options that control how Parquet files are read, including global options /// that apply to all columns and optional column-specific overrides /// -/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions). +/// Closely tied to `ParquetWriterOptions` (see `crate::file_options::parquet_writer::ParquetWriterOptions` when the "parquet" feature is enabled). /// Properties not included in [`TableParquetOptions`] may not be configurable at the external API /// (e.g. sorting_columns). 
#[derive(Clone, Default, Debug, PartialEq)] @@ -3065,6 +3071,22 @@ config_namespace! { /// If not specified, the default level for the compression algorithm is used. pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None + /// The JSON format to use when reading files. + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, default = true } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index f67e7e4517d2b..de0aacf9e8bcd 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -698,10 +698,12 @@ impl DFSchema { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() + Self::datatype_is_logically_equal(v1.as_ref(), v2.as_ref()) + } + (DataType::Dictionary(_, v1), othertype) + | (othertype, DataType::Dictionary(_, v1)) => { + Self::datatype_is_logically_equal(v1.as_ref(), othertype) } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { @@ -1798,6 +1800,27 @@ mod tests { &DataType::Utf8, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) )); + + // Dictionary is logically equal to the logically equivalent value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8View, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + + 
assert!(DFSchema::datatype_is_logically_equal( + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8, false).into() + )) + ), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8View, false).into() + )) + ) + )); } #[test] diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 98dd1f235aee7..3be6118c55ff2 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -20,15 +20,19 @@ use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::*; +use arrow::compute::take; use arrow::datatypes::*; #[cfg(not(feature = "force_hash_collisions"))] use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use itertools::Itertools; +use std::collections::HashMap; #[cfg(not(feature = "force_hash_collisions"))] use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, - as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_string_array, as_string_view_array, as_struct_array, as_union_array, + as_generic_binary_array, as_large_list_array, as_large_list_view_array, + as_list_array, as_list_view_array, as_map_array, as_string_array, + as_string_view_array, as_struct_array, as_union_array, }; use crate::error::Result; use crate::error::{_internal_datafusion_err, _internal_err}; @@ -390,33 +394,22 @@ fn hash_generic_byte_view_array( } } -/// Helper function to update hash for a dictionary key if the value is valid -#[cfg(not(feature = "force_hash_collisions"))] -#[inline] -fn update_hash_for_dict_key( - hash: &mut u64, - dict_hashes: &[u64], - dict_values: &dyn Array, - idx: usize, - multi_col: bool, -) { - if dict_values.is_valid(idx) { - if multi_col { - *hash = combine_hashes(dict_hashes[idx], *hash); - } else { - *hash = dict_hashes[idx]; - } - } - // no update for 
invalid dictionary value -} - -/// Hash the values in a dictionary array -#[cfg(not(feature = "force_hash_collisions"))] -fn hash_dictionary( +/// Hash dictionary array with compile-time specialization for null handling. +/// +/// Uses const generics to eliminate runtim branching in the hot loop: +/// - `HAS_NULL_KEYS`: Whether to check for null dictionary keys +/// - `HAS_NULL_VALUES`: Whether to check for null dictionary values +/// - `MULTI_COL`: Whether to combine with existing hash (true) or initialize (false) +#[inline(never)] +fn hash_dictionary_inner< + K: ArrowDictionaryKeyType, + const HAS_NULL_KEYS: bool, + const HAS_NULL_VALUES: bool, + const MULTI_COL: bool, +>( array: &DictionaryArray, random_state: &RandomState, hashes_buffer: &mut [u64], - multi_col: bool, ) -> Result<()> { // Hash each dictionary value once, and then use that computed // hash for each key value to avoid a potentially expensive @@ -425,22 +418,91 @@ fn hash_dictionary( let mut dict_hashes = vec![0; dict_values.len()]; create_hashes([dict_values], random_state, &mut dict_hashes)?; - // combine hash for each index in values - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { + if HAS_NULL_KEYS { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { + if let Some(key) = key { + let idx = key.as_usize(); + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } + } + } else { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().values()) { let idx = key.as_usize(); - update_hash_for_dict_key( - hash, - &dict_hashes, - dict_values.as_ref(), - idx, - multi_col, - ); - } // no update for Null key + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } } Ok(()) } +/// Hash the values in a 
dictionary array +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_dictionary( + array: &DictionaryArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + multi_col: bool, +) -> Result<()> { + let has_null_keys = array.keys().null_count() != 0; + let has_null_values = array.values().null_count() != 0; + + // Dispatcher based on null presence and multi-column mode + // Should reduce branching within hot loops + match (has_null_keys, has_null_values, multi_col) { + (false, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_struct_array( array: &StructArray, @@ -450,19 +512,21 @@ fn hash_struct_array( let nulls = array.nulls(); let row_len = array.len(); - let valid_row_indices: Vec = if let Some(nulls) = nulls { - nulls.valid_indices().collect() - } else { - (0..row_len).collect() - }; - // Create hashes for each row that combines the hashes over all the column at that row. 
let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - for i in valid_row_indices { - let hash = &mut hashes_buffer[i]; - *hash = combine_hashes(*hash, values_hashes[i]); + // Separate paths to avoid allocating Vec when there are no nulls + if let Some(nulls) = nulls { + for i in nulls.valid_indices() { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + } else { + for i in 0..row_len { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } } Ok(()) @@ -479,15 +543,29 @@ fn hash_map_array( let offsets = array.offsets(); // Create hashes for each entry in each row - let mut values_hashes = vec![0u64; array.entries().len()]; - create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + let first_offset = offsets.first().copied().unwrap_or_default() as usize; + let last_offset = offsets.last().copied().unwrap_or_default() as usize; + let entries_len = last_offset - first_offset; + + // Only hash the entries that are actually referenced + let mut values_hashes = vec![0u64; entries_len]; + let entries = array.entries(); + let sliced_columns: Vec = entries + .columns() + .iter() + .map(|col| col.slice(first_offset, entries_len)) + .collect(); + create_hashes(&sliced_columns, random_state, &mut values_hashes)?; // Combine the hashes for entries on each row with each other and previous hash for that row + // Adjust indices by first_offset since values_hashes is sliced starting from first_offset if let Some(nulls) = nulls { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -495,7 +573,9 @@ fn 
hash_map_array( } else { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -510,27 +590,83 @@ fn hash_list_array( random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> +where + OffsetSize: OffsetSizeTrait, +{ + // In case values is sliced, hash only the bytes used by the offsets of this ListArray + let first_offset = array.value_offsets().first().cloned().unwrap_or_default(); + let last_offset = array.value_offsets().last().cloned().unwrap_or_default(); + let value_bytes_len = (last_offset - first_offset).as_usize(); + let mut values_hashes = vec![0u64; value_bytes_len]; + create_hashes( + [array + .values() + .slice(first_offset.as_usize(), value_bytes_len)], + random_state, + &mut values_hashes, + )?; + + if array.null_count() > 0 { + for (i, (start, stop)) in array.value_offsets().iter().tuple_windows().enumerate() + { + if array.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[(*start - first_offset).as_usize() + ..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for ((start, stop), hash) in array + .value_offsets() + .iter() + .tuple_windows() + .zip(hashes_buffer.iter_mut()) + { + for values_hash in &values_hashes + [(*start - first_offset).as_usize()..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_list_view_array( + array: &GenericListViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> where OffsetSize: OffsetSizeTrait, { let values = array.values(); let offsets = array.value_offsets(); + let 
sizes = array.value_sizes(); let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } } } else { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } @@ -544,14 +680,42 @@ fn hash_union_array( random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - use std::collections::HashMap; - let DataType::Union(union_fields, _mode) = array.data_type() else { unreachable!() }; - let mut child_hashes = HashMap::with_capacity(union_fields.len()); + if array.is_dense() { + // Dense union: children only contain values of their type, so they're already compact. + // Use the default hashing approach which is efficient for dense unions. + hash_union_array_default(array, union_fields, random_state, hashes_buffer) + } else { + // Sparse union: each child has the same length as the union array. + // Optimization: only hash the elements that are actually referenced by type_ids, + // instead of hashing all K*N elements (where K = num types, N = array length). 
+ hash_sparse_union_array(array, union_fields, random_state, hashes_buffer) + } +} + +/// Default hashing for union arrays - hashes all elements of each child array fully. +/// +/// This approach works for both dense and sparse union arrays: +/// - Dense unions: children are compact (each child only contains values of that type) +/// - Sparse unions: children have the same length as the union array +/// +/// For sparse unions with 3+ types, the optimized take/scatter approach in +/// `hash_sparse_union_array` is more efficient, but for 1-2 types or dense unions, +/// this simpler approach is preferred. +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array_default( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let mut child_hashes: HashMap> = + HashMap::with_capacity(union_fields.len()); + // Hash each child array fully for (type_id, _field) in union_fields.iter() { let child = array.child(type_id); let mut child_hash_buffer = vec![0; child.len()]; @@ -560,6 +724,9 @@ fn hash_union_array( child_hashes.insert(type_id, child_hash_buffer); } + // Combine hashes for each row using the appropriate child offset + // For dense unions: value_offset points to the actual position in the child + // For sparse unions: value_offset equals the row index #[expect(clippy::needless_range_loop)] for i in 0..array.len() { let type_id = array.type_id(i); @@ -572,6 +739,69 @@ fn hash_union_array( Ok(()) } +/// Hash a sparse union array. +/// Sparse unions have child arrays with the same length as the union array. +/// For 3+ types, we optimize by only hashing the N elements that are actually used +/// (via take/scatter), instead of hashing all K*N elements. +/// +/// For 1-2 types, the overhead of take/scatter outweighs the benefit, so we use +/// the default approach of hashing all children (same as dense unions). 
+#[cfg(not(feature = "force_hash_collisions"))] +fn hash_sparse_union_array( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + use std::collections::HashMap; + + // For 1-2 types, the take/scatter overhead isn't worth it. + // Fall back to the default approach (same as dense union). + if union_fields.len() <= 2 { + return hash_union_array_default( + array, + union_fields, + random_state, + hashes_buffer, + ); + } + + let type_ids = array.type_ids(); + + // Group indices by type_id + let mut indices_by_type: HashMap> = HashMap::new(); + for (i, &type_id) in type_ids.iter().enumerate() { + indices_by_type.entry(type_id).or_default().push(i as u32); + } + + // For each type, extract only the needed elements, hash them, and scatter back + for (type_id, _field) in union_fields.iter() { + if let Some(indices) = indices_by_type.get(&type_id) { + if indices.is_empty() { + continue; + } + + let child = array.child(type_id); + let indices_array = UInt32Array::from(indices.clone()); + + // Extract only the elements we need using take() + let filtered = take(child.as_ref(), &indices_array, None)?; + + // Hash the filtered array + let mut filtered_hashes = vec![0u64; filtered.len()]; + create_hashes([&filtered], random_state, &mut filtered_hashes)?; + + // Scatter hashes back to correct positions + for (hash, &idx) in filtered_hashes.iter().zip(indices.iter()) { + hashes_buffer[idx as usize] = + combine_hashes(hashes_buffer[idx as usize], *hash); + } + } + } + + Ok(()) +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, @@ -605,12 +835,17 @@ fn hash_fixed_list_array( Ok(()) } +/// Inner hash function for RunArray +#[inline(never)] #[cfg(not(feature = "force_hash_collisions"))] -fn hash_run_array( +fn hash_run_array_inner< + R: RunEndIndexType, + const HAS_NULL_VALUES: bool, + const REHASH: bool, +>( array: &RunArray, random_state: &RandomState, 
hashes_buffer: &mut [u64], - rehash: bool, ) -> Result<()> { // We find the relevant runs that cover potentially sliced arrays, so we can only hash those // values. Then we find the runs that refer to the original runs and ensure that we apply @@ -648,25 +883,23 @@ fn hash_run_array( .iter() .enumerate() { - let is_null_value = sliced_values.is_null(adjusted_physical_index); let absolute_run_end = absolute_run_end.as_usize(); - let end_in_slice = (absolute_run_end - array_offset).min(array_len); - if rehash { - if !is_null_value { - let value_hash = values_hashes[adjusted_physical_index]; - for hash in hashes_buffer - .iter_mut() - .take(end_in_slice) - .skip(start_in_slice) - { - *hash = combine_hashes(value_hash, *hash); - } + if HAS_NULL_VALUES && sliced_values.is_null(adjusted_physical_index) { + start_in_slice = end_in_slice; + continue; + } + + let value_hash = values_hashes[adjusted_physical_index]; + let run_slice = &mut hashes_buffer[start_in_slice..end_in_slice]; + + if REHASH { + for hash in run_slice.iter_mut() { + *hash = combine_hashes(value_hash, *hash); } } else { - let value_hash = values_hashes[adjusted_physical_index]; - hashes_buffer[start_in_slice..end_in_slice].fill(value_hash); + run_slice.fill(value_hash); } start_in_slice = end_in_slice; @@ -675,6 +908,31 @@ fn hash_run_array( Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + let has_null_values = array.values().null_count() != 0; + + match (has_null_values, rehash) { + (false, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (false, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + } +} + /// Internal helper function that 
hashes a single array and either initializes or combines /// the hash values in the buffer. #[cfg(not(feature = "force_hash_collisions"))] @@ -714,6 +972,14 @@ fn hash_single_array( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::ListView(_) => { + let array = as_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::LargeListView(_) => { + let array = as_large_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } DataType::Map(_, _) => { let array = as_map_array(array)?; hash_map_array(array, random_state, hashes_buffer)?; @@ -1128,6 +1394,130 @@ mod tests { assert_eq!(hashes[1], hashes[6]); // null vs empty list } + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Slice from here + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + // To here + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), + ]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let list_array = list_array.slice(2, 3); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[1], hashes[2]); + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create ListView 
with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i32, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i32, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let list_view_array = + Arc::new(ListViewArray::new(field, offsets, sizes, values, nulls)) + as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_view_array.len()]; + create_hashes(&[list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_large_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create LargeListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: 
[3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i64, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i64, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let large_list_view_array = Arc::new(LargeListViewArray::new( + field, offsets, sizes, values, nulls, + )) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; large_list_view_array.len()]; + create_hashes(&[large_list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index df6659c6f843c..fdd04f752455e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] mod column; mod dfschema; diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs index eb687bde07d0b..d6d8fb7b0ed0c 100644 --- a/datafusion/common/src/metadata.rs +++ 
b/datafusion/common/src/metadata.rs @@ -171,6 +171,10 @@ pub fn format_type_and_metadata( /// // Add any metadata from `FieldMetadata` to `Field` /// let updated_field = metadata.add_to_field(field); /// ``` +/// +/// For more background, please also see the [Implementing User Defined Types and Custom Metadata in DataFusion blog] +/// +/// [Implementing User Defined Types and Custom Metadata in DataFusion blog]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct FieldMetadata { /// The inner metadata of a literal expression, which is a map of string diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index 086d96e85230d..bf2558f313069 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -19,9 +19,9 @@ use crate::error::{_plan_err, Result}; use arrow::{ array::{Array, ArrayRef, StructArray, new_null_array}, compute::{CastOptions, cast_with_options}, - datatypes::{DataType::Struct, Field, FieldRef}, + datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; /// Cast a struct column to match target struct fields, handling nested structs recursively. 
/// @@ -31,6 +31,7 @@ use std::sync::Arc; /// /// ## Field Matching Strategy /// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **No Positional Mapping**: Structs with no overlapping field names are rejected /// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type /// - **Missing Fields**: Target fields not present in the source are filled with null values /// - **Extra Fields**: Source fields not present in the target are ignored @@ -54,16 +55,30 @@ fn cast_struct_column( target_fields: &[Arc], cast_options: &CastOptions, ) -> Result { - if let Some(source_struct) = source_col.as_any().downcast_ref::() { - validate_struct_compatibility(source_struct.fields(), target_fields)?; + if source_col.data_type() == &DataType::Null + || (!source_col.is_empty() && source_col.null_count() == source_col.len()) + { + return Ok(new_null_array( + &Struct(target_fields.to_vec().into()), + source_col.len(), + )); + } + if let Some(source_struct) = source_col.as_any().downcast_ref::() { + let source_fields = source_struct.fields(); + validate_struct_compatibility(source_fields, target_fields)?; let mut fields: Vec> = Vec::with_capacity(target_fields.len()); let mut arrays: Vec = Vec::with_capacity(target_fields.len()); let num_rows = source_col.len(); - for target_child_field in target_fields { + // Iterate target fields and pick source child by name when present. + for target_child_field in target_fields.iter() { fields.push(Arc::clone(target_child_field)); - match source_struct.column_by_name(target_child_field.name()) { + + let source_child_opt = + source_struct.column_by_name(target_child_field.name()); + + match source_child_opt { Some(source_child_col) => { let adapted_child = cast_column(source_child_col, target_child_field, cast_options) @@ -200,10 +215,20 @@ pub fn cast_column( /// // Target: {a: binary} /// // Result: Err(...) 
- string cannot cast to binary /// ``` +/// pub fn validate_struct_compatibility( source_fields: &[FieldRef], target_fields: &[FieldRef], ) -> Result<()> { + let has_overlap = has_one_of_more_common_fields(source_fields, target_fields); + if !has_overlap { + return _plan_err!( + "Cannot cast struct with {} fields to {} fields because there is no field name overlap", + source_fields.len(), + target_fields.len() + ); + } + // Check compatibility for each target field for target_field in target_fields { // Look for matching field in source by name @@ -211,53 +236,102 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { - // Ensure nullability is compatible. It is invalid to cast a nullable - // source field to a non-nullable target field as this may discard - // null values. - if source_field.is_nullable() && !target_field.is_nullable() { + validate_field_compatibility(source_field, target_field)?; + } else { + // Target field is missing from source + // If it's non-nullable, we cannot fill it with NULL + if !target_field.is_nullable() { return _plan_err!( - "Cannot cast nullable struct field '{}' to non-nullable field", + "Cannot cast struct: target field '{}' is non-nullable but missing from source. 
\ + Cannot fill with NULL.", target_field.name() ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs - (Struct(source_nested), Struct(target_nested)) => { - validate_struct_compatibility(source_nested, target_nested)?; - } - // For non-struct types, use the existing castability check - _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { - return _plan_err!( - "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() - ); - } - } - } } - // Missing fields in source are OK - they'll be filled with nulls } // Extra fields in source are OK - they'll be ignored Ok(()) } +fn validate_field_compatibility( + source_field: &Field, + target_field: &Field, +) -> Result<()> { + if source_field.data_type() == &DataType::Null { + // Validate that target allows nulls before returning early. + // It is invalid to cast a NULL source field to a non-nullable target field. + if !target_field.is_nullable() { + return _plan_err!( + "Cannot cast NULL struct field '{}' to non-nullable field '{}'", + source_field.name(), + target_field.name() + ); + } + return Ok(()); + } + + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. 
+ if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } + + // Check if the matching field types are compatible + match (source_field.data_type(), target_field.data_type()) { + // Recursively validate nested structs + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + // For non-struct types, use the existing castability check + _ => { + if !arrow::compute::can_cast_types( + source_field.data_type(), + target_field.data_type(), + ) { + return _plan_err!( + "Cannot cast struct field '{}' from type {} to type {}", + target_field.name(), + source_field.data_type(), + target_field.data_type() + ); + } + } + } + + Ok(()) +} + +/// Check if two field lists have at least one common field by name. +/// +/// This is useful for validating struct compatibility when casting between structs, +/// ensuring that source and target fields have overlapping names. 
+pub fn has_one_of_more_common_fields( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> bool { + let source_names: HashSet<&str> = source_fields + .iter() + .map(|field| field.name().as_str()) + .collect(); + target_fields + .iter() + .any(|field| source_names.contains(field.name().as_str())) +} + #[cfg(test)] mod tests { use super::*; - use crate::format::DEFAULT_CAST_OPTIONS; + use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS}; use arrow::{ array::{ BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, StringArray, StringBuilder, + MapBuilder, NullArray, StringArray, StringBuilder, }, buffer::NullBuffer, datatypes::{DataType, Field, FieldRef, Int32Type}, @@ -428,11 +502,14 @@ mod tests { #[test] fn test_validate_struct_compatibility_missing_field_in_source() { - // Source struct: {field2: String} (missing field1) - let source_fields = vec![arc_field("field2", DataType::Utf8)]; + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; - // Target struct: {field1: Int32} - let target_fields = vec![arc_field("field1", DataType::Int32)]; + // Target struct: {field1: Int32, field2: Utf8} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; // Should be OK - missing fields will be filled with nulls let result = validate_struct_compatibility(&source_fields, &target_fields); @@ -455,6 +532,20 @@ mod tests { assert!(result.is_ok()); } + #[test] + fn test_validate_struct_compatibility_no_overlap_mismatch_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Int32), + ]; + let target_fields = vec![arc_field("alpha", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } 
+ #[test] fn test_cast_struct_parent_nulls_retained() { let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; @@ -525,6 +616,117 @@ mod tests { assert!(error_msg.contains("non-nullable")); } + #[test] + fn test_validate_struct_compatibility_by_name() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field2: String, field1: Int64} + let target_fields = vec![ + arc_field("field2", DataType::Utf8), + arc_field("field1", DataType::Int64), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_with_type_mismatch() { + // Source struct: {field1: Binary} + let source_fields = vec![arc_field("field1", DataType::Binary)]; + + // Target struct: {field1: Int32} (incompatible type) + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct field 'field1' from type Binary to type Int32" + ); + } + + #[test] + fn test_validate_struct_compatibility_no_overlap_equal_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Utf8), + ]; + + let target_fields = vec![ + arc_field("alpha", DataType::Int32), + arc_field("beta", DataType::Utf8), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_validate_struct_compatibility_mixed_name_overlap() { + // Source struct: {a: Int32, b: String, extra: Boolean} + let source_fields = vec![ + arc_field("a", 
DataType::Int32), + arc_field("b", DataType::Utf8), + arc_field("extra", DataType::Boolean), + ]; + + // Target struct: {b: String, a: Int64, c: Float32} + // Name overlap with a and b, missing c (nullable) + let target_fields = vec![ + arc_field("b", DataType::Utf8), + arc_field("a", DataType::Int64), + arc_field("c", DataType::Float32), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_missing_required_field() { + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32, field2: Int32 non-nullable} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + Arc::new(non_null_field("field2", DataType::Int32)), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL." 
+ ); + } + + #[test] + fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() { + // Source struct: {a: Int32} (only one field) + let source_fields = vec![arc_field("a", DataType::Int32)]; + + // Target struct: {a: Int32, b: String} (two fields, but 'a' overlaps) + let target_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + ]; + + // This should succeed - partial overlap means by-name mapping + // and missing field 'b' is nullable + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + #[test] fn test_cast_nested_struct_with_extra_and_missing_fields() { // Source inner struct has fields a, b, extra @@ -585,6 +787,33 @@ mod tests { assert!(missing.is_null(1)); } + #[test] + fn test_cast_null_struct_field_to_nested_struct() { + let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("inner", DataType::Null), + Arc::clone(&null_inner), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "outer", + vec![struct_field("inner", vec![field("a", DataType::Int32)])], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.len(), 2); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + + let inner_a = get_column_as!(inner, "a", Int32Array); + assert!(inner_a.is_null(0)); + assert!(inner_a.is_null(1)); + } + #[test] fn test_cast_struct_with_array_and_map_fields() { // Array field with second row null @@ -704,4 +933,81 @@ mod tests { assert_eq!(a_col.value(0), 1); assert_eq!(a_col.value(1), 2); } + + #[test] + fn test_cast_struct_no_overlap_rejected() { + let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; + let second = + 
Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("left", DataType::Int32), first), + (arc_field("right", DataType::Utf8), second), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int64), field("b", DataType::Utf8)], + ); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_cast_struct_missing_non_nullable_field_fails() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' (nullable) and 'b' (non-nullable) + let target_field = struct_field( + "s", + vec![ + field("a", DataType::Int32), + non_null_field("b", DataType::Int32), + ], + ); + + // Should fail because 'b' is non-nullable but missing from source + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("target field 'b' is non-nullable but missing from source"), + "Unexpected error: {err}" + ); + } + + #[test] + fn test_cast_struct_missing_nullable_field_succeeds() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' and 'b' (both nullable) + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Int32)], + ); + + // Should succeed - 'b' is nullable 
so can be filled with NULL + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + + let b_col = get_column_as!(&struct_array, "b", Int32Array); + assert!(b_col.is_null(0)); + assert!(b_col.is_null(1)); + } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index e4e048ad3c0d8..644916d7891c4 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -43,7 +43,7 @@ use crate::cast::{ as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array, as_interval_ym_array, as_large_binary_array, as_large_list_array, - as_large_string_array, as_string_array, as_string_view_array, + as_large_string_array, as_run_array, as_string_array, as_string_view_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, @@ -56,21 +56,20 @@ use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - Array, ArrayData, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, - BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Array, ArrayData, ArrayDataBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + AsArray, BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, - 
FixedSizeBinaryBuilder, FixedSizeListArray, Float16Array, Float32Array, Float64Array, - GenericListArray, Int8Array, Int16Array, Int32Array, Int64Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray, - MutableArrayData, OffsetSizeTrait, PrimitiveArray, Scalar, StringArray, - StringViewArray, StringViewBuilder, StructArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, UnionArray, - new_empty_array, new_null_array, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, + Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, + PrimitiveArray, RunArray, Scalar, StringArray, StringViewArray, StringViewBuilder, + StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, + UInt64Array, UnionArray, downcast_run_array, new_empty_array, new_null_array, }; use arrow::buffer::{BooleanBuffer, ScalarBuffer}; use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; @@ -80,11 +79,12 @@ use arrow::compute::kernels::numeric::{ use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, - Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, + FieldRef, Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, 
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, - UInt64Type, UnionFields, UnionMode, i256, validate_decimal_precision_and_scale, + IntervalYearMonthType, RunEndIndexType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, UnionFields, UnionMode, i256, + validate_decimal_precision_and_scale, }; use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}; use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; @@ -429,6 +429,8 @@ pub enum ScalarValue { Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), + /// (run-ends field, value field, value) + RunEndEncoded(FieldRef, FieldRef, Box), } impl Hash for Fl { @@ -558,6 +560,10 @@ impl PartialEq for ScalarValue { (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + rf1.eq(rf2) && vf1.eq(vf2) && v1.eq(v2) + } + (RunEndEncoded(_, _, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -723,6 +729,15 @@ impl PartialOrd for ScalarValue { if k1 == k2 { v1.partial_cmp(v2) } else { None } } (Dictionary(_, _), _) => None, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + // Don't compare if the run ends fields don't match (it is effectively a different datatype) + if rf1 == rf2 && vf1 == vf2 { + v1.partial_cmp(v2) + } else { + None + } + } + (RunEndEncoded(_, _, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -966,6 +981,11 @@ impl Hash for ScalarValue { k.hash(state); v.hash(state); } + RunEndEncoded(rf, vf, v) => 
{ + rf.hash(state); + vf.hash(state); + v.hash(state); + } // stable hash for Null value Null => 1.hash(state), } @@ -1244,6 +1264,13 @@ impl ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), + DataType::RunEndEncoded(run_ends_field, value_field) => { + ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(value_field.data_type().try_into()?), + ) + } // `ScalarValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), @@ -1574,6 +1601,8 @@ impl ScalarValue { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Timestamp(_, _) @@ -1642,6 +1671,14 @@ impl ScalarValue { Box::new(ScalarValue::new_default(value_type)?), )), + DataType::RunEndEncoded(run_ends_field, value_field) => { + Ok(ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(ScalarValue::new_default(value_field.data_type())?), + )) + } + // Map types DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( ArrayData::new_empty(field.data_type()), @@ -1661,8 +1698,7 @@ impl ScalarValue { } } - // Unsupported types for now - _ => { + DataType::ListView(_) | DataType::LargeListView(_) => { _not_impl_err!( "Default value for data_type \"{datatype}\" is not implemented yet" ) @@ -1953,6 +1989,12 @@ impl ScalarValue { ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } + ScalarValue::RunEndEncoded(run_ends_field, value_field, _) => { + DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + ) + } ScalarValue::Null => DataType::Null, } } @@ -2231,6 +2273,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::RunEndEncoded(_, _, v) => 
v.is_null(), } } @@ -2598,6 +2641,94 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + DataType::RunEndEncoded(run_ends_field, value_field) => { + fn make_run_array( + scalars: impl IntoIterator, + run_ends_field: &FieldRef, + values_field: &FieldRef, + ) -> Result { + let mut scalars = scalars.into_iter(); + + let mut run_ends = vec![]; + let mut value_scalars = vec![]; + + let mut len = R::Native::ONE; + let mut current = + if let Some(ScalarValue::RunEndEncoded(_, _, scalar)) = + scalars.next() + { + *scalar + } else { + // We are guaranteed to have one element of correct + // type because we peeked above + unreachable!() + }; + for scalar in scalars { + let scalar = match scalar { + ScalarValue::RunEndEncoded( + inner_run_ends_field, + inner_value_field, + scalar, + ) if &inner_run_ends_field == run_ends_field + && &inner_value_field == values_field => + { + *scalar + } + _ => { + return _exec_err!( + "Expected RunEndEncoded scalar with run-ends field {run_ends_field} but got: {scalar:?}" + ); + } + }; + + // new run + if scalar != current { + run_ends.push(len); + value_scalars.push(current); + current = scalar; + } + + len = len.add_checked(R::Native::ONE).map_err(|_| { + DataFusionError::Execution(format!( + "Cannot construct RunArray: Overflows run-ends type {}", + run_ends_field.data_type() + )) + })?; + } + + run_ends.push(len); + value_scalars.push(current); + + let run_ends = PrimitiveArray::::from_iter_values(run_ends); + let values = ScalarValue::iter_to_array(value_scalars)?; + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(RunArray::logical_len(&run_ends)) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + + match run_ends_field.data_type() 
{ + DataType::Int16 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int32 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int64 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -2626,7 +2757,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { return _not_impl_err!( @@ -2989,13 +3119,8 @@ impl ScalarValue { }, ScalarValue::Utf8View(e) => match e { Some(value) => { - let mut builder = - StringViewBuilder::with_capacity(size).with_deduplicate_strings(); - // Replace with upstream arrow-rs code when available: - // https://github.com/apache/arrow-rs/issues/9034 - for _ in 0..size { - builder.append_value(value); - } + let mut builder = StringViewBuilder::with_capacity(size); + builder.try_append_value_n(value, size)?; let array = builder.finish(); Arc::new(array) } @@ -3013,11 +3138,8 @@ impl ScalarValue { }, ScalarValue::BinaryView(e) => match e { Some(value) => { - let mut builder = - BinaryViewBuilder::with_capacity(size).with_deduplicate_strings(); - for _ in 0..size { - builder.append_value(value); - } + let mut builder = BinaryViewBuilder::with_capacity(size); + builder.try_append_value_n(value, size)?; let array = builder.finish(); Arc::new(array) } @@ -3031,14 +3153,7 @@ impl ScalarValue { ) .unwrap(), ), - None => { - // TODO: Replace with FixedSizeBinaryArray::new_null once a fix for - // https://github.com/apache/arrow-rs/issues/8900 is in the used arrow-rs - // version. 
- let mut builder = FixedSizeBinaryBuilder::new(*s); - builder.append_nulls(size); - Arc::new(builder.finish()) - } + None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), }, ScalarValue::LargeBinary(e) => match e { Some(value) => { @@ -3218,6 +3333,54 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + ScalarValue::RunEndEncoded(run_ends_field, values_field, value) => { + fn make_run_array( + run_ends_field: &Arc, + values_field: &Arc, + value: &ScalarValue, + size: usize, + ) -> Result { + let size_native = R::Native::from_usize(size) + .ok_or_else(|| DataFusionError::Execution(format!("Cannot construct RunArray of size {size}: Overflows run-ends type {}", R::DATA_TYPE)))?; + let values = value.to_array_of_size(1)?; + let run_ends = + PrimitiveArray::::new(vec![size_native].into(), None); + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(size) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + match run_ends_field.data_type() { + DataType::Int16 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int32 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int64 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ -3568,6 +3731,28 @@ impl ScalarValue { Self::Dictionary(key_type.clone(), Box::new(value)) } + DataType::RunEndEncoded(run_ends_field, value_field) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of 
bounds for array of length {}", + array.len() + ); + } + let scalar = downcast_run_array!( + array => { + let index = array.get_physical_index(index); + ScalarValue::try_from_array(array.values(), index)? + }, + dt => unreachable!("Invalid run-ends type: {dt}") + ); + Self::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(scalar), + ) + } DataType::Struct(_) => { let a = array.slice(index, 1); Self::Struct(Arc::new(a.as_struct().to_owned())) @@ -3680,6 +3865,7 @@ impl ScalarValue { ScalarValue::LargeUtf8(v) => v, ScalarValue::Utf8View(v) => v, ScalarValue::Dictionary(_, v) => return v.try_as_str(), + ScalarValue::RunEndEncoded(_, _, v) => return v.try_as_str(), _ => return None, }; Some(v.as_ref().map(|v| v.as_str())) @@ -3704,7 +3890,23 @@ impl ScalarValue { } let scalar_array = self.to_array()?; - let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; + + // For struct types, use name-based casting logic that matches fields by name + // and recursively casts nested structs. The field name wrapper is arbitrary + // since cast_column only uses the DataType::Struct field definitions inside. + let cast_arr = match target_type { + DataType::Struct(_) => { + // Field name is unused; only the struct's inner field names matter + let target_field = Field::new("_", target_type.clone(), true); + crate::nested_struct::cast_column( + &scalar_array, + &target_field, + cast_options, + )? 
+ } + _ => cast_with_options(&scalar_array, target_type, cast_options)?, + }; + ScalarValue::try_from_array(&cast_arr, 0) } @@ -4008,6 +4210,34 @@ impl ScalarValue { None => v.is_null(), } } + ScalarValue::RunEndEncoded(run_ends_field, _, value) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + match run_ends_field.data_type() { + DataType::Int16 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int32 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int64 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => array.is_null(index), }) } @@ -4097,6 +4327,7 @@ impl ScalarValue { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() } + ScalarValue::RunEndEncoded(rf, vf, v) => rf.size() + vf.size() + v.size(), } } @@ -4212,6 +4443,9 @@ impl ScalarValue { ScalarValue::Dictionary(_, value) => { value.compact(); } + ScalarValue::RunEndEncoded(_, _, value) => { + value.compact(); + } } } @@ -4843,6 +5077,7 @@ impl fmt::Display for ScalarValue { None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, + ScalarValue::RunEndEncoded(_, _, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -5021,6 +5256,9 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), + ScalarValue::RunEndEncoded(rf, vf, v) => { + write!(f, "RunEndEncoded({rf:?}, {vf:?}, {v:?})") + } 
ScalarValue::Null => write!(f, "NULL"), } } @@ -7256,6 +7494,31 @@ mod tests { } } + #[test] + fn roundtrip_run_array() { + // Comparison logic in round_trip_through_scalar doesn't work for RunArrays + // so we have a custom test for them + // TODO: https://github.com/apache/arrow-rs/pull/9213 might fix this ^ + let run_ends = Int16Array::from(vec![2, 3]); + let values = Int64Array::from(vec![Some(1), None]); + let run_array = RunArray::try_new(&run_ends, &values).unwrap(); + let run_array = run_array.downcast::().unwrap(); + + let expected_values = run_array.into_iter().collect::>(); + + for i in 0..run_array.len() { + let scalar = ScalarValue::try_from_array(&run_array, i).unwrap(); + let array = scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.data_type(), run_array.data_type()); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!( + array.into_iter().collect::>(), + expected_values[i..i + 1] + ); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -8868,7 +9131,7 @@ mod tests { .unwrap(), ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), ScalarValue::try_new_null(&DataType::Union( - UnionFields::new(vec![42], vec![field_ref]), + UnionFields::try_new(vec![42], vec![field_ref]).unwrap(), UnionMode::Dense, )) .unwrap(), @@ -8971,13 +9234,14 @@ mod tests { } // Test union type - let union_fields = UnionFields::new( + let union_fields = UnionFields::try_new( vec![0, 1], vec![ Field::new("i32", DataType::Int32, false), Field::new("f64", DataType::Float64, false), ], - ); + ) + .unwrap(); let union_result = ScalarValue::new_default(&DataType::Union( union_fields.clone(), UnionMode::Sparse, @@ -9227,6 +9491,175 @@ mod tests { assert_eq!(value.len(), buffers[0].len()); } + #[test] + fn test_to_array_of_size_run_end_encoded() { + fn run_test() { + let value = Box::new(ScalarValue::Float32(Some(1.0))); + let size = 5; + let scalar = 
ScalarValue::RunEndEncoded( + Field::new("run_ends", R::DATA_TYPE, false).into(), + Field::new("values", DataType::Float32, true).into(), + value.clone(), + ); + let array = scalar.to_array_of_size(size).unwrap(); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!(vec![Some(1.0); size], array.into_iter().collect::>()); + assert_eq!(1, array.values().len()); + } + + run_test::(); + run_test::(); + run_test::(); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let err = scalar.to_array_of_size(i16::MAX as usize + 10).unwrap_err(); + assert_eq!( + "Execution error: Cannot construct RunArray of size 32777: Overflows run-ends type Int16", + err.to_string() + ) + } + + #[test] + fn test_eq_array_run_end_encoded() { + let run_ends = Int16Array::from(vec![1, 3]); + let values = Float32Array::from(vec![None, Some(1.0)]); + let run_array = + Arc::new(RunArray::try_new(&run_ends, &values).unwrap()) as ArrayRef; + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + assert!(scalar.eq_array(&run_array, 0).unwrap()); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + assert!(scalar.eq_array(&run_array, 1).unwrap()); + assert!(scalar.eq_array(&run_array, 2).unwrap()); + + // value types must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float64, true).into(), + Box::new(ScalarValue::Float64(Some(1.0))), + ); + let err = scalar.eq_array(&run_array, 1).unwrap_err(); + let expected = "Internal 
error: could not cast array of type Float32 to arrow_array::array::primitive_array::PrimitiveArray"; + assert!(err.to_string().starts_with(expected)); + + // run ends type must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + let err = scalar.eq_array(&run_array, 0).unwrap_err(); + let expected = "Internal error: could not cast array of type RunEndEncoded(\"run_ends\": non-null Int16, \"values\": Float32) to arrow_array::array::run_array::RunArray"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_iter_to_array_run_end_encoded() { + let run_ends_field = Arc::new(Field::new("run_ends", DataType::Int16, false)); + let values_field = Arc::new(Field::new("values", DataType::Int64, true)); + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(None)), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ]; + + let run_array = ScalarValue::iter_to_array(scalars).unwrap(); + let expected = RunArray::try_new( + &Int16Array::from(vec![2, 3, 6]), + &Int64Array::from(vec![Some(1), None, Some(2)]), + ) + .unwrap(); + assert_eq!(&expected as &dyn Array, run_array.as_ref()); + + // 
inconsistent run-ends type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int32 }, Field { name: \"values\", data_type: Int64, nullable: true }, Int64(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent value type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Field::new("values", DataType::Int32, true).into(), + Box::new(ScalarValue::Int32(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int16 }, Field { name: \"values\", data_type: Int32, nullable: true }, Int32(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent scalars type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::Int64(Some(1)), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: Int64(1)"; + assert!(err.to_string().starts_with(expected)); + } + #[test] fn test_convert_array_to_scalar_vec() { // 1: Regular 
ListArray diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index ba13ef392d912..cecf1d03418d7 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -391,8 +391,13 @@ impl Statistics { /// For example, if we had statistics for columns `{"a", "b", "c"}`, /// projecting to `vec![2, 1]` would return statistics for columns `{"c", /// "b"}`. - pub fn project(mut self, projection: Option<&Vec>) -> Self { - let Some(projection) = projection else { + pub fn project(self, projection: Option<&impl AsRef<[usize]>>) -> Self { + let projection = projection.map(AsRef::as_ref); + self.project_impl(projection) + } + + fn project_impl(mut self, projection: Option<&[usize]>) -> Self { + let Some(projection) = projection.map(AsRef::as_ref) else { return self; }; @@ -410,7 +415,7 @@ impl Statistics { .map(Slot::Present) .collect(); - for idx in projection { + for idx in projection.iter() { let next_idx = self.column_statistics.len(); let slot = std::mem::replace( columns.get_mut(*idx).expect("projection out of bounds"), @@ -1066,7 +1071,7 @@ mod tests { #[test] fn test_project_none() { - let projection = None; + let projection: Option> = None; let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); assert_eq!(stats, make_stats(vec![10, 20, 30])); } diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 766c50441613b..5ef90b7209854 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -186,7 +186,57 @@ pub enum NativeType { impl Display for NativeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") // TODO: nicer formatting + // Match the format used by arrow::datatypes::DataType's Display impl + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int8 => write!(f, "Int8"), + Self::Int16 => write!(f, "Int16"), + Self::Int32 => write!(f, 
"Int32"), + Self::Int64 => write!(f, "Int64"), + Self::UInt8 => write!(f, "UInt8"), + Self::UInt16 => write!(f, "UInt16"), + Self::UInt32 => write!(f, "UInt32"), + Self::UInt64 => write!(f, "UInt64"), + Self::Float16 => write!(f, "Float16"), + Self::Float32 => write!(f, "Float32"), + Self::Float64 => write!(f, "Float64"), + Self::Timestamp(unit, Some(tz)) => write!(f, "Timestamp({unit}, {tz:?})"), + Self::Timestamp(unit, None) => write!(f, "Timestamp({unit})"), + Self::Date => write!(f, "Date"), + Self::Time(unit) => write!(f, "Time({unit})"), + Self::Duration(unit) => write!(f, "Duration({unit})"), + Self::Interval(unit) => write!(f, "Interval({unit:?})"), + Self::Binary => write!(f, "Binary"), + Self::FixedSizeBinary(size) => write!(f, "FixedSizeBinary({size})"), + Self::String => write!(f, "String"), + Self::List(field) => write!(f, "List({})", field.logical_type), + Self::FixedSizeList(field, size) => { + write!(f, "FixedSizeList({size} x {})", field.logical_type) + } + Self::Struct(fields) => { + write!(f, "Struct(")?; + for (i, field) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}: {}", field.name, field.logical_type)?; + } + write!(f, ")") + } + Self::Union(fields) => { + write!(f, "Union(")?; + for (i, (type_id, field)) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{type_id}: ({:?}: {})", field.name, field.logical_type)?; + } + write!(f, ")") + } + Self::Decimal(precision, scale) => write!(f, "Decimal({precision}, {scale})"), + Self::Map(field) => write!(f, "Map({})", field.logical_type), + } } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 03310a7bde193..7f2d78d57970e 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -70,10 +70,10 @@ use std::thread::available_parallelism; /// ``` pub fn project_schema( schema: &SchemaRef, - projection: Option<&Vec>, + projection: Option<&impl AsRef<[usize]>>, ) 
-> Result { let schema = match projection { - Some(columns) => Arc::new(schema.project(columns)?), + Some(columns) => Arc::new(schema.project(columns.as_ref())?), None => Arc::clone(schema), }; Ok(schema) @@ -516,6 +516,7 @@ impl SingleRowListArrayBuilder { /// ); /// /// assert_eq!(list_arr, expected); +/// ``` pub fn arrays_into_list_array( arr: impl IntoIterator, ) -> Result { @@ -587,6 +588,7 @@ pub enum ListCoercion { /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); +/// ``` pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index fddf834912544..846c928515d60 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -121,6 +121,8 @@ pub trait HashTableAllocExt { /// /// Returns the bucket where the element was inserted. /// Note that allocation counts capacity, not size. + /// Panics: + /// Assumes the element is not already present, and may panic if it does /// /// # Example: /// ``` @@ -134,7 +136,7 @@ pub trait HashTableAllocExt { /// assert_eq!(allocated, 64); /// /// // insert more values - /// for i in 0..100 { + /// for i in 2..100 { /// table.insert_accounted(i, hash_fn, &mut allocated); /// } /// assert_eq!(allocated, 400); @@ -161,22 +163,24 @@ where ) { let hash = hasher(&x); - // NOTE: `find_entry` does NOT grow! 
- match self.find_entry(hash, |y| y == &x) { - Ok(_occupied) => {} - Err(_absent) => { - if self.len() == self.capacity() { - // need to request more memory - let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * size_of::(); - *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + if cfg!(debug_assertions) { + // In debug mode, check that the element is not already present + debug_assert!( + self.find_entry(hash, |y| y == &x).is_err(), + "attempted to insert duplicate element into HashTableAllocExt::insert_accounted" + ); + } - self.reserve(bump_elements, &hasher); - } + if self.len() == self.capacity() { + // need to request more memory + let bump_elements = self.capacity().max(16); + let bump_size = bump_elements * size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); - // still need to insert the element since first try failed - self.entry(hash, |y| y == &x, hasher).insert(x); - } + self.reserve(bump_elements, &hasher); } + + // We assume the element is not already present + self.insert_unique(hash, x, hasher); } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index bd88ed3b9ca1e..3d0a76a182697 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -158,7 +158,7 @@ sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.19", features = ["v4", "js"] } +uuid = { workspace = true, features = ["v4", "js"] } zstd = { workspace = true, optional = true } [dev-dependencies] @@ -175,12 +175,14 @@ env_logger = { workspace = true } glob = { workspace = true } insta = { workspace = true } paste = { workspace = true } +pretty_assertions = "1.0" rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.5" +recursive = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = 
"0.37.2" +sysinfo = "0.38.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } @@ -188,7 +190,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] ignored = ["datafusion-doc", "datafusion-macros", "dashmap"] [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } [[bench]] harness = false @@ -239,6 +241,11 @@ harness = false name = "parquet_query_sql" required-features = ["parquet"] +[[bench]] +harness = false +name = "parquet_struct_query" +required-features = ["parquet"] + [[bench]] harness = false name = "range_and_generate_series" @@ -280,3 +287,7 @@ name = "spm" harness = false name = "preserve_file_partitioning" required-features = ["parquet"] + +[[bench]] +harness = false +name = "reset_plan_states" diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 4aa667504e459..f785c9458003f 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs index 228457947fd5a..13843dadddd0c 100644 --- a/datafusion/core/benches/csv_load.rs +++ b/datafusion/core/benches/csv_load.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::prelude::CsvReadOptions; diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 630bc056600b4..accd51ae5861c 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -36,6 +36,7 @@ use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, /// and the result table will be of array_len in total, and then partitioned, and batched. +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub fn create_table_provider( partitions_len: usize, @@ -183,6 +184,7 @@ impl TraceIdBuilder { /// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition /// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub(crate) fn make_data( partition_cnt: i32, diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 726187ab5e922..5aeade315cc7b 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -15,13 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - use arrow_schema::{DataType, Field, Schema}; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_expr::col; diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index 0e638e293d8cf..d389b1b3d6a22 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::{create_table_provider, make_data}; use datafusion::execution::context::SessionContext; use datafusion::physical_plan::{ExecutionPlan, collect}; diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs index 4d1d4abb6783c..f5df56e95a2d8 100644 --- a/datafusion/core/benches/math_query_sql.rs +++ b/datafusion/core/benches/math_query_sql.rs @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -extern crate arrow; -extern crate datafusion; - use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, diff --git a/datafusion/core/benches/parquet_struct_query.rs b/datafusion/core/benches/parquet_struct_query.rs new file mode 100644 index 0000000000000..17ba17e02ba80 --- /dev/null +++ b/datafusion/core/benches/parquet_struct_query.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Benchmarks of SQL queries on struct columns in parquet data + +use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_common::instant::Instant; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::{WriterProperties, WriterVersion}; +use rand::distr::Alphanumeric; +use rand::prelude::*; +use rand::rng; +use std::hint::black_box; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; +use tempfile::NamedTempFile; +use tokio::runtime::Runtime; + +/// The number of batches to write +const NUM_BATCHES: usize = 128; +/// The number of rows in each record batch to write +const WRITE_RECORD_BATCH_SIZE: usize = 4096; +/// The number of rows in a row group +const ROW_GROUP_SIZE: usize = 65536; +/// The number of row groups expected +const EXPECTED_ROW_GROUPS: usize = 8; +/// The range for random string lengths +const STRING_LENGTH_RANGE: Range = 50..200; + +fn schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ]); + let struct_type = DataType::Struct(struct_fields); + + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", struct_type, false), + ])) +} + +fn generate_strings(len: usize) -> ArrayRef { + let mut rng = rng(); + Arc::new(StringArray::from_iter((0..len).map(|_| { + let string_len = rng.random_range(STRING_LENGTH_RANGE.clone()); + Some( + (0..string_len) + .map(|_| char::from(rng.sample(Alphanumeric))) + .collect::(), + ) + }))) +} + +fn generate_batch(batch_id: usize) -> RecordBatch { + let schema = schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + // Generate sequential IDs based on batch_id for uniqueness + let base_id = (batch_id * len) as i32; + let id_values: 
Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + + // Create struct id array (matching top-level id) + let struct_id_array = Arc::new(Int32Array::from(id_values)); + + // Generate random strings for struct value field + let value_array = generate_strings(len); + + // Construct StructArray + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, false)), + struct_id_array as ArrayRef, + ), + ( + Arc::new(Field::new("value", DataType::Utf8, false)), + value_array, + ), + ]); + + RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap() +} + +fn generate_file() -> NamedTempFile { + let now = Instant::now(); + let mut named_file = tempfile::Builder::new() + .prefix("parquet_struct_query") + .suffix(".parquet") + .tempfile() + .unwrap(); + + println!("Generating parquet file - {}", named_file.path().display()); + let schema = schema(); + + let properties = WriterProperties::builder() + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_max_row_group_size(ROW_GROUP_SIZE) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut named_file, schema, Some(properties)).unwrap(); + + for batch_id in 0..NUM_BATCHES { + let batch = generate_batch(batch_id); + writer.write(&batch).unwrap(); + } + + let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); + let expected_rows = WRITE_RECORD_BATCH_SIZE * NUM_BATCHES; + assert_eq!( + file_metadata.num_rows() as usize, + expected_rows, + "Expected {} rows but got {}", + expected_rows, + file_metadata.num_rows() + ); + assert_eq!( + metadata.row_groups().len(), + EXPECTED_ROW_GROUPS, + "Expected {} row groups but got {}", + EXPECTED_ROW_GROUPS, + metadata.row_groups().len() + ); + + println!( + "Generated parquet file with {} rows and {} row groups in {} seconds", + file_metadata.num_rows(), + metadata.row_groups().len(), + now.elapsed().as_secs_f32() + ); + + named_file 
+} + +fn create_context(file_path: &str) -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + rt.block_on(ctx.register_parquet("t", file_path, Default::default())) + .unwrap(); + ctx +} + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let ctx = ctx.clone(); + let sql = sql.to_string(); + let df = rt.block_on(ctx.sql(&sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn criterion_benchmark(c: &mut Criterion) { + let (file_path, temp_file) = match std::env::var("PARQUET_FILE") { + Ok(file) => (file, None), + Err(_) => { + let temp_file = generate_file(); + (temp_file.path().display().to_string(), Some(temp_file)) + } + }; + + assert!(Path::new(&file_path).exists(), "path not found"); + println!("Using parquet file {file_path}"); + + let ctx = create_context(&file_path); + let rt = Runtime::new().unwrap(); + + // Basic struct access + c.bench_function("struct_access", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t")) + }); + + // Filter queries + c.bench_function("filter_struct_field_eq", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where s['id'] = 5")) + }); + + c.bench_function("filter_struct_field_with_select", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t where s['id'] = 5")) + }); + + c.bench_function("filter_top_level_with_struct_select", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t where id = 5")) + }); + + c.bench_function("filter_struct_string_length", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where length(s['value']) > 100")) + }); + + c.bench_function("filter_struct_range", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id from t where s['id'] > 100 and s['id'] < 200", + ) + }) + }); + + // Join queries (limited with WHERE id < 1000 for performance) + c.bench_function("join_struct_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] where t1.id < 1000" + 
)) + }); + + c.bench_function("join_struct_to_toplevel", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_toplevel_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.id = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_struct_with_top_level", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] and t1.id = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_and_struct_value", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.s['id'], t2.s['value'] from t t1 join t t2 on t1.id = t2.id where t1.id < 1000" + )) + }); + + // Group by queries + c.bench_function("group_by_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t group by s['id']")) + }); + + c.bench_function("group_by_struct_select_toplevel", |b| { + b.iter(|| query(&ctx, &rt, "select max(id) from t group by s['id']")) + }); + + c.bench_function("group_by_toplevel_select_struct", |b| { + b.iter(|| query(&ctx, &rt, "select max(s['id']) from t group by id")) + }); + + c.bench_function("group_by_struct_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select s['id'], count(*) from t group by s['id']", + ) + }) + }); + + c.bench_function("group_by_multiple_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'], count(*) from t group by id, s['id']", + ) + }) + }); + + // Additional queries + c.bench_function("order_by_struct_limit", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'] from t order by s['id'] limit 1000", + ) + }) + }); + + c.bench_function("distinct_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select distinct s['id'] from t")) + }); + + // Temporary file must outlive the benchmarks, it is deleted when dropped + drop(temp_file); +} + +criterion_group!(benches, 
criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index e6763b4761c2a..7b66996b05929 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::{BatchSize, Criterion}; -extern crate arrow; -extern crate datafusion; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use std::sync::Arc; diff --git a/datafusion/core/benches/preserve_file_partitioning.rs b/datafusion/core/benches/preserve_file_partitioning.rs index 17ebca52cd1d2..9b1f59adc6823 100644 --- a/datafusion/core/benches/preserve_file_partitioning.rs +++ b/datafusion/core/benches/preserve_file_partitioning.rs @@ -322,7 +322,7 @@ async fn save_plans( } } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn run_benchmark( c: &mut Criterion, rt: &Runtime, diff --git a/datafusion/core/benches/range_and_generate_series.rs b/datafusion/core/benches/range_and_generate_series.rs index 2b1463a21062a..10d560df0813e 100644 --- a/datafusion/core/benches/range_and_generate_series.rs +++ b/datafusion/core/benches/range_and_generate_series.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::hint::black_box; diff --git a/datafusion/core/benches/reset_plan_states.rs b/datafusion/core/benches/reset_plan_states.rs new file mode 100644 index 0000000000000..f2f81f755b96e --- /dev/null +++ b/datafusion/core/benches/reset_plan_states.rs @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, LazyLock}; + +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_catalog::MemTable; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::displayable; +use datafusion_physical_plan::execution_plan::reset_plan_states; +use tokio::runtime::Runtime; + +const NUM_FIELDS: usize = 1000; +const PREDICATE_LEN: usize = 50; + +static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new( + (0..NUM_FIELDS) + .map(|i| Arc::new(Field::new(format!("x_{i}"), DataType::Int64, false))) + .collect::(), + )) +}); + +fn col_name(i: usize) -> String { + format!("x_{i}") +} + +fn aggr_name(i: usize) -> String { + format!("aggr_{i}") +} + +fn physical_plan( + ctx: &SessionContext, + rt: &Runtime, + sql: &str, +) -> Arc { + rt.block_on(async { + ctx.sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap() + }) +} + +fn predicate(col_name: impl Fn(usize) -> String, len: usize) -> String { + let mut predicate = String::new(); + for i in 0..len { + if i > 0 { + predicate.push_str(" AND "); + } + predicate.push_str(&col_name(i)); + predicate.push_str(" = "); + predicate.push_str(&i.to_string()); + } + predicate +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT aggr1(col1) as aggr1, aggr2(col2) as aggr2 FROM t +/// WHERE p1 +/// HAVING p2 +/// ``` +/// +/// Where `p1` and `p2` some long predicates. 
+/// +fn query1() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in 0..NUM_FIELDS { + if i > 0 { + query.push_str(", "); + } + query.push_str("AVG("); + query.push_str(&col_name(i)); + query.push_str(") AS "); + query.push_str(&aggr_name(i)); + } + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query.push_str(" HAVING "); + query.push_str(&predicate(aggr_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t JOIN v ON t.a = v.a +/// WHERE p1 +/// ``` +/// +fn query2() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in (0..NUM_FIELDS).step_by(2) { + if i > 0 { + query.push_str(", "); + } + if (i / 2) % 2 == 0 { + query.push_str(&format!("t.{}", col_name(i))); + } else { + query.push_str(&format!("v.{}", col_name(i))); + } + } + query.push_str(" FROM t JOIN v ON t.x_0 = v.x_0 WHERE "); + + fn qualified_name(i: usize) -> String { + format!("t.{}", col_name(i)) + } + + query.push_str(&predicate(qualified_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t +/// WHERE p +/// ``` +/// +fn query3() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + + // Create non-trivial projection. 
+ for i in 0..NUM_FIELDS / 2 { + if i > 0 { + query.push_str(", "); + } + query.push_str(&col_name(i * 2)); + query.push_str(" + "); + query.push_str(&col_name(i * 2 + 1)); + } + + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query +} + +fn run_reset_states(b: &mut criterion::Bencher, plan: &Arc) { + b.iter(|| std::hint::black_box(reset_plan_states(Arc::clone(plan)).unwrap())); +} + +/// Benchmark is intended to measure overhead of actions, required to perform +/// making an independent instance of the execution plan to re-execute it, avoiding +/// re-planning stage. +fn bench_reset_plan_states(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let ctx = SessionContext::new(); + ctx.register_table( + "t", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + ctx.register_table( + "v", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + macro_rules! bench_query { + ($query_producer: expr) => {{ + let sql = $query_producer(); + let plan = physical_plan(&ctx, &rt, &sql); + log::debug!("plan:\n{}", displayable(plan.as_ref()).indent(true)); + move |b| run_reset_states(b, &plan) + }}; + } + + c.bench_function("query1", bench_query!(query1)); + c.bench_function("query2", bench_query!(query2)); + c.bench_function("query3", bench_query!(query3)); +} + +criterion_group!(benches, bench_reset_plan_states); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index c18070fb7725e..54cd9a0bcd547 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, @@ -27,9 +25,6 @@ use datafusion::prelude::SessionConfig; use parking_lot::Mutex; use std::sync::Arc; -extern crate arrow; -extern crate datafusion; - use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7cce7e0bd7db7..59502da987904 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,20 +15,15 @@ // specific language governing permissions and limitations // under the License. -extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; use arrow::array::PrimitiveArray; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::ArrowNativeTypeOp; use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::{DataType, Field, Fields, Schema}; use criterion::Bencher; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; use datafusion_common::{ScalarValue, config::Dialect}; @@ -78,6 +73,21 @@ fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc Arc { + let struct_fields = Fields::from(vec![ + Field::new("value", DataType::Int32, true), + Field::new("label", DataType::Utf8, true), + ]); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("props", DataType::Struct(struct_fields), true), + ])); + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); 
ctx.register_table("t1", create_table_provider("a", 200)) @@ -88,6 +98,10 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t1000", create_table_provider("d", 1000)) .unwrap(); + ctx.register_table("struct_t1", create_struct_table_provider()) + .unwrap(); + ctx.register_table("struct_t2", create_struct_table_provider()) + .unwrap(); ctx } @@ -118,6 +132,11 @@ fn register_clickbench_hits_table(rt: &Runtime) -> SessionContext { let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + // ClickBench partitioned dataset was written by an ancient version of pyarrow that + // that wrote strings with the wrong logical type. To read it correctly, we must + // automatically convert binary to string. + rt.block_on(ctx.sql("SET datafusion.execution.parquet.binary_as_string = true;")) + .unwrap(); rt.block_on(ctx.sql(&sql)).unwrap(); let count = @@ -419,6 +438,25 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); + let struct_agg_sort_query = "SELECT \ + struct_t1.props['label'], \ + SUM(struct_t1.props['value']), \ + MAX(struct_t2.props['value']), \ + COUNT(*) \ + FROM struct_t1 \ + JOIN struct_t2 ON struct_t1.id = struct_t2.id \ + WHERE struct_t1.props['value'] > 50 \ + GROUP BY struct_t1.props['label'] \ + ORDER BY SUM(struct_t1.props['value']) DESC"; + + // -- Struct column benchmarks -- + c.bench_function("logical_plan_struct_join_agg_sort", |b| { + b.iter(|| logical_plan(&ctx, &rt, struct_agg_sort_query)) + }); + c.bench_function("physical_plan_struct_join_agg_sort", |b| { + b.iter(|| physical_plan(&ctx, &rt, struct_agg_sort_query)) + }); + // -- Sorted Queries -- // 100, 200 && 300 is taking too long - https://github.com/apache/datafusion/issues/18366 // Logical Plan for datatype Int64 and UInt64 differs, UInt64 Logical Plan's Union are wrapped diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index 7979efdec605e..f71cf1087be7d 100644 --- 
a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -17,6 +17,9 @@ mod data_utils; +use arrow::array::Int64Builder; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::make_data; @@ -24,12 +27,53 @@ use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use datafusion_execution::config::SessionConfig; +use rand::SeedableRng; +use rand::seq::SliceRandom; use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; const LIMIT: usize = 10; +/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids +/// This ensures consistent results across benchmark runs +fn make_distinct_data( + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Vec>)> { + let mut rng = rand::rngs::SmallRng::from_seed([42; 32]); + let total_samples = partition_cnt as usize * sample_cnt as usize; + let mut ids = Vec::new(); + for i in 0..total_samples { + ids.push(i as i64); + } + ids.shuffle(&mut rng); + + let mut global_idx = 0; + let schema = test_distinct_schema(); + let mut partitions = vec![]; + for _ in 0..partition_cnt { + let mut id_builder = Int64Builder::new(); + + for _ in 0..sample_cnt { + let id = ids[global_idx]; + id_builder.append_value(id); + global_idx += 1; + } + + let id_col = Arc::new(id_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?; + partitions.push(vec![batch]); + } + + Ok((schema, partitions)) +} + +/// Returns a Schema for distinct benchmarks with i64 trace_id +fn test_distinct_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])) +} + async fn create_context( partition_cnt: i32, sample_cnt: i32, @@ -50,6 +94,25 @@ async fn create_context( 
Ok(ctx) } +async fn create_context_distinct( + partition_cnt: i32, + sample_cnt: i32, + use_topk: bool, +) -> Result { + // Use deterministic data generation for DISTINCT queries to ensure consistent results + let (schema, parts) = make_distinct_data(partition_cnt, sample_cnt).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + + Ok(ctx) +} + fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) { black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap(); } @@ -59,6 +122,17 @@ fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) { .unwrap(); } +fn run_distinct( + rt: &Runtime, + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) { + black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await })) + .unwrap(); +} + async fn aggregate( ctx: SessionContext, limit: usize, @@ -133,6 +207,84 @@ async fn aggregate_string( Ok(()) } +async fn aggregate_distinct( + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) -> Result<()> { + let order_direction = if asc { "asc" } else { "desc" }; + let sql = format!( + "select id from traces group by id order by id {order_direction} limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), LIMIT); + + let actual = format!("{}", 
pretty_format_batches(&batches)?).to_lowercase(); + + let expected_asc = r#" ++----+ +| id | ++----+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----+ +"# + .trim(); + + let expected_desc = r#" ++---------+ +| id | ++---------+ +| 9999999 | +| 9999998 | +| 9999997 | +| 9999996 | +| 9999995 | +| 9999994 | +| 9999993 | +| 9999992 | +| 9999991 | +| 9999990 | ++---------+ +"# + .trim(); + + // Verify exact results match expected values + if asc { + assert_eq!( + actual.trim(), + expected_asc, + "Ascending DISTINCT results do not match expected values" + ); + } else { + assert_eq!( + actual.trim(), + expected_desc, + "Descending DISTINCT results do not match expected values" + ); + } + + Ok(()) +} + fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let limit = LIMIT; @@ -253,6 +405,37 @@ fn criterion_benchmark(c: &mut Criterion) { .as_str(), |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), ); + + // DISTINCT benchmarks + let ctx = rt.block_on(async { + create_context_distinct(partitions, samples, false) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)), + ); + + let ctx_topk = rt.block_on(async { + create_context_distinct(partitions, samples, true) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)), + ); } criterion_group!(benches, criterion_benchmark); diff --git 
a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs index e4643567a0f0c..1657cae913fef 100644 --- a/datafusion/core/benches/window_query_sql.rs +++ b/datafusion/core/benches/window_query_sql.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 74a10bf079e61..2466d42692192 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -84,30 +84,6 @@ fn print_window_docs() -> Result { print_docs(providers, window_doc_sections::doc_sections()) } -// Temporary method useful to semi automate -// the migration of UDF documentation generation from code based -// to attribute based -// To be removed -#[allow(dead_code)] -fn save_doc_code_text(documentation: &Documentation, name: &str) { - let attr_text = documentation.to_doc_attribute(); - - let file_path = format!("{name}.txt"); - if std::path::Path::new(&file_path).exists() { - std::fs::remove_file(&file_path).unwrap(); - } - - // Open the file in append mode, create it if it doesn't exist - let mut file = std::fs::OpenOptions::new() - .append(true) // Open in append mode - .create(true) // Create the file if it doesn't exist - .open(file_path) - .unwrap(); - - use std::io::Write; - file.write_all(attr_text.as_bytes()).unwrap(); -} - #[expect(clippy::needless_pass_by_value)] fn print_docs( providers: Vec>, @@ -306,8 +282,7 @@ impl DocProvider for WindowUDF { } } -#[allow(clippy::borrowed_box)] -#[allow(clippy::ptr_arg)] 
+#[expect(clippy::borrowed_box)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { functions .iter() diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fe760760eef3f..2292f5855bfde 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -78,9 +78,11 @@ pub struct DataFrameWriteOptions { /// Controls how new data should be written to the table, determining whether /// to append, overwrite, or replace existing data. insert_op: InsertOp, - /// Controls if all partitions should be coalesced into a single output file - /// Generally will have slower performance when set to true. - single_file_output: bool, + /// Controls if all partitions should be coalesced into a single output file. + /// - `None`: Use automatic mode (extension-based heuristic) + /// - `Some(true)`: Force single file output at exact path + /// - `Some(false)`: Force directory output with generated filenames + single_file_output: Option, /// Sets which columns should be used for hive-style partitioned writes by name. /// Can be set to empty vec![] for non-partitioned writes. partition_by: Vec, @@ -94,7 +96,7 @@ impl DataFrameWriteOptions { pub fn new() -> Self { DataFrameWriteOptions { insert_op: InsertOp::Append, - single_file_output: false, + single_file_output: None, partition_by: vec![], sort_by: vec![], } @@ -108,9 +110,13 @@ impl DataFrameWriteOptions { /// Set the single_file_output value to true or false /// - /// When set to true, an output file will always be created even if the DataFrame is empty + /// - `true`: Force single file output at the exact path specified + /// - `false`: Force directory output with generated filenames + /// + /// When not called, automatic mode is used (extension-based heuristic). + /// When set to true, an output file will always be created even if the DataFrame is empty. 
pub fn with_single_file_output(mut self, single_file_output: bool) -> Self { - self.single_file_output = single_file_output; + self.single_file_output = Some(single_file_output); self } @@ -125,6 +131,15 @@ impl DataFrameWriteOptions { self.sort_by = sort_by; self } + + /// Build the options HashMap to pass to CopyTo for sink configuration. + fn build_sink_options(&self) -> HashMap { + let mut options = HashMap::new(); + if let Some(single_file) = self.single_file_output { + options.insert("single_file_output".to_string(), single_file.to_string()); + } + options + } } impl Default for DataFrameWriteOptions { @@ -447,15 +462,31 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` - pub fn drop_columns(self, columns: &[&str]) -> Result { + pub fn drop_columns(self, columns: &[T]) -> Result + where + T: Into + Clone, + { let fields_to_drop = columns .iter() - .flat_map(|name| { - self.plan - .schema() - .qualified_fields_with_unqualified_name(name) + .flat_map(|col| { + let column: Column = col.clone().into(); + match column.relation.as_ref() { + Some(_) => { + // qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)> + vec![self.plan.schema().qualified_field_from_column(&column)] + } + None => { + // qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)> + self.plan + .schema() + .qualified_fields_with_unqualified_name(&column.name) + .into_iter() + .map(Ok) + .collect::>() + } + } }) - .collect::>(); + .collect::, _>>()?; let expr: Vec = self .plan .schema() @@ -481,7 +512,7 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_json("tests/data/unnest.json", NdJsonReadOptions::default()).await?; + /// let df = ctx.read_json("tests/data/unnest.json", JsonReadOptions::default()).await?; /// // expand into multiple columns if it's json array, flatten field name if it's nested structure /// let df = 
df.unnest_columns(&["b","c","d"])?; /// let expected = vec![ @@ -2024,6 +2055,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -2036,7 +2069,7 @@ impl DataFrame { plan, path.into(), file_type, - HashMap::new(), + copy_options, options.partition_by, )? .build()?; @@ -2092,6 +2125,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -2104,7 +2139,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? .build()?; @@ -2465,6 +2500,48 @@ impl DataFrame { .collect() } + /// Find qualified columns for this dataframe from names + /// + /// # Arguments + /// * `names` - Unqualified names to find. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df = ctx.table("first_table").await?; + /// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df2 = ctx.table("second_table").await?; + /// let join_expr = df.find_qualified_columns(&["a"])?.iter() + /// .zip(df2.find_qualified_columns(&["a"])?.iter()) + /// .map(|(col1, col2)| col(*col1).eq(col(*col2))) + /// .collect::>(); + /// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?; + /// # Ok(()) + /// # } + /// ``` + pub fn find_qualified_columns( + &self, + names: &[&str], + ) -> Result, &FieldRef)>> { + let schema = self.logical_plan().schema(); + names + .iter() + .map(|name| { + schema + 
.qualified_field_from_column(&Column::from_name(*name)) + .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) + }) + .collect() + } + /// Helper for creating DataFrame. /// # Example /// ``` diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 6edf628e2d6d6..54dadfd78cbc2 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -76,6 +76,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -88,7 +90,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? .build()?; @@ -324,4 +326,156 @@ mod tests { Ok(()) } + + /// Test FileOutputMode::SingleFile - explicitly request single file output + /// for paths WITHOUT file extensions. This verifies the fix for the regression + /// where extension heuristics ignored the explicit with_single_file_output(true). 
+ #[tokio::test] + async fn test_file_output_mode_single_file() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITHOUT .parquet extension - this is the key scenario + let output_path = tmp_dir.path().join("data_no_ext"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request single file output + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(true), + None, + ) + .await?; + + // Verify: output should be a FILE, not a directory + assert!( + output_path.is_file(), + "Expected single file at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the file is readable as parquet + let file = std::fs::File::open(&output_path)?; + let reader = parquet::file::reader::SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + assert_eq!(metadata.file_metadata().num_rows(), 3); + + Ok(()) + } + + /// Test FileOutputMode::Automatic - uses extension heuristic. + /// Path WITH extension -> single file; path WITHOUT extension -> directory. 
+ #[tokio::test] + async fn test_file_output_mode_automatic() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + // Case 1: Path WITH extension -> should create single file (Automatic mode) + let output_with_ext = tmp_dir.path().join("data.parquet"); + let df = ctx.read_batch(batch.clone())?; + df.write_parquet( + output_with_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_with_ext.is_file(), + "Path with extension should be a single file, got is_file={}, is_dir={}", + output_with_ext.is_file(), + output_with_ext.is_dir() + ); + + // Case 2: Path WITHOUT extension -> should create directory (Automatic mode) + let output_no_ext = tmp_dir.path().join("data_dir"); + let df = ctx.read_batch(batch)?; + df.write_parquet( + output_no_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_no_ext.is_dir(), + "Path without extension should be a directory, got is_file={}, is_dir={}", + output_no_ext.is_file(), + output_no_ext.is_dir() + ); + + Ok(()) + } + + /// Test FileOutputMode::Directory - explicitly request directory output + /// even for paths WITH file extensions. 
+ #[tokio::test] + async fn test_file_output_mode_directory() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITH .parquet extension but explicitly requesting directory output + let output_path = tmp_dir.path().join("output.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request directory output (single_file_output = false) + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(false), + None, + ) + .await?; + + // Verify: output should be a DIRECTORY, not a single file + assert!( + output_path.is_dir(), + "Expected directory at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the directory contains parquet file(s) + let entries: Vec<_> = std::fs::read_dir(&output_path)? 
+ .filter_map(|e| e.ok()) + .collect(); + assert!( + !entries.is_empty(), + "Directory should contain at least one file" + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index cad35d43db486..7cf23ee294d86 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -95,7 +95,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); assert_eq!( vec![ @@ -109,7 +109,7 @@ mod tests { "double_col: Float64", "date_string_col: Binary", "string_col: Binary", - "timestamp_col: Timestamp(Microsecond, None)", + "timestamp_col: Timestamp(µs)", ], x ); diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index cb2e9d787ee92..5b3e22705620e 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -25,7 +25,7 @@ mod tests { use super::*; use crate::datasource::file_format::test_util::scan_format; - use crate::prelude::{NdJsonReadOptions, SessionConfig, SessionContext}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::RecordBatch; use arrow_schema::Schema; @@ -46,12 +46,54 @@ mod tests { use datafusion_common::internal_err; use datafusion_common::stats::Precision; + use crate::execution::options::JsonReadOptions; use datafusion_common::Result; + use datafusion_datasource::file_compression_type::FileCompressionType; use futures::StreamExt; use insta::assert_snapshot; use object_store::local::LocalFileSystem; use regex::Regex; use rstest::rstest; + // ==================== Test Helpers ==================== + + /// Create a temporary JSON file and return (TempDir, path) + fn create_temp_json(content: &str) -> 
(tempfile::TempDir, String) { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = tmp_dir.path().join("test.json"); + std::fs::write(&path, content).unwrap(); + (tmp_dir, path.to_string_lossy().to_string()) + } + + /// Infer schema from JSON array format file + async fn infer_json_array_schema( + content: &str, + ) -> Result { + let (_tmp_dir, path) = create_temp_json(content); + let session = SessionContext::new(); + let ctx = session.state(); + let store = Arc::new(LocalFileSystem::new()) as _; + let format = JsonFormat::default().with_newline_delimited(false); + format + .infer_schema(&ctx, &store, &[local_unpartitioned_file(&path)]) + .await + } + + /// Register a JSON array table and run a query + async fn query_json_array(content: &str, query: &str) -> Result> { + let (_tmp_dir, path) = create_temp_json(content); + let ctx = SessionContext::new(); + let options = JsonReadOptions::default().newline_delimited(false); + ctx.register_json("test_table", &path, options).await?; + ctx.sql(query).await?.collect().await + } + + /// Register a JSON array table and run a query, return formatted string + async fn query_json_array_str(content: &str, query: &str) -> Result { + let result = query_json_array(content, query).await?; + Ok(batches_to_string(&result)) + } + + // ==================== Existing Tests ==================== #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -208,7 +250,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/1.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel", table_path, options) .await?; @@ -240,7 +282,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/empty.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel_empty", table_path, options) .await?; @@ -314,7 +356,6 @@ 
mod tests { .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); let mut all_batches = RecordBatch::new_empty(schema.clone()); - // We get RequiresMoreData after 2 batches because of how json::Decoder works for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { @@ -354,11 +395,11 @@ mod tests { async fn test_write_empty_json_from_sql() -> Result<()> { let ctx = SessionContext::new(); let tmp_dir = tempfile::TempDir::new()?; - let path = format!("{}/empty_sql.json", tmp_dir.path().to_string_lossy()); + let path = tmp_dir.path().join("empty_sql.json"); + let path = path.to_string_lossy().to_string(); let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) .await?; - // Expected the file to exist and be empty assert!(std::path::Path::new(&path).exists()); let metadata = std::fs::metadata(&path)?; assert_eq!(metadata.len(), 0); @@ -381,14 +422,216 @@ mod tests { )?; let tmp_dir = tempfile::TempDir::new()?; - let path = format!("{}/empty_batch.json", tmp_dir.path().to_string_lossy()); + let path = tmp_dir.path().join("empty_batch.json"); + let path = path.to_string_lossy().to_string(); let df = ctx.read_batch(empty_batch.clone())?; df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) .await?; - // Expected the file to exist and be empty assert!(std::path::Path::new(&path).exists()); let metadata = std::fs::metadata(&path)?; assert_eq!(metadata.len(), 0); Ok(()) } + + // ==================== JSON Array Format Tests ==================== + + #[tokio::test] + async fn test_json_array_schema_inference() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"a": 1, "b": 2.0, "c": true}, {"a": 2, "b": 3.5, "c": false}]"#, + ) + .await?; + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + assert_eq!(vec!["a: 
Int64", "b: Float64", "c: Boolean"], fields); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_empty() -> Result<()> { + let schema = infer_json_array_schema("[]").await?; + assert_eq!(schema.fields().len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"id": 1, "info": {"name": "Alice", "age": 30}}]"#, + ) + .await?; + + let info_field = schema.field_with_name("info").unwrap(); + assert!(matches!(info_field.data_type(), DataType::Struct(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_type() -> Result<()> { + let schema = + infer_json_array_schema(r#"[{"id": 1, "tags": ["a", "b", "c"]}]"#).await?; + + let tags_field = schema.field_with_name("tags").unwrap(); + assert!(matches!(tags_field.data_type(), DataType::List(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_basic_query() -> Result<()> { + let result = query_json_array_str( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}, {"a": 3, "b": "test"}]"#, + "SELECT a, b FROM test_table ORDER BY a", + ) + .await?; + + assert_snapshot!(result, @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + | 3 | test | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_nulls() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "name": "Alice"}, {"id": 2, "name": null}, {"id": 3, "name": "Charlie"}]"#, + "SELECT id, name FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+ + | id | name | + +----+---------+ + | 1 | Alice | + | 2 | | + | 3 | Charlie | + +----+---------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "values": [10, 20, 30]}, {"id": 2, "values": [40, 50]}]"#, + "SELECT id, unnest(values) as value FROM test_table ORDER BY id, value", + ) + .await?; + + 
assert_snapshot!(result, @r" + +----+-------+ + | id | value | + +----+-------+ + | 1 | 10 | + | 1 | 20 | + | 1 | 30 | + | 2 | 40 | + | 2 | 50 | + +----+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest_struct() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "orders": [{"product": "A", "qty": 2}, {"product": "B", "qty": 3}]}, {"id": 2, "orders": [{"product": "C", "qty": 1}]}]"#, + "SELECT id, unnest(orders)['product'] as product, unnest(orders)['qty'] as qty FROM test_table ORDER BY id, product", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+-----+ + | id | product | qty | + +----+---------+-----+ + | 1 | A | 2 | + | 1 | B | 3 | + | 2 | C | 1 | + +----+---------+-----+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct_access() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "dept": {"name": "Engineering", "head": "Alice"}}, {"id": 2, "dept": {"name": "Sales", "head": "Bob"}}]"#, + "SELECT id, dept['name'] as dept_name, dept['head'] as head FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------------+-------+ + | id | dept_name | head | + +----+-------------+-------+ + | 1 | Engineering | Alice | + | 2 | Sales | Bob | + +----+-------------+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_compression() -> Result<()> { + use flate2::Compression; + use flate2::write::GzEncoder; + use std::io::Write; + + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("array.json.gz"); + let path = path.to_string_lossy().to_string(); + + let file = std::fs::File::create(&path)?; + let mut encoder = GzEncoder::new(file, Compression::default()); + encoder.write_all( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}]"#.as_bytes(), + )?; + encoder.finish()?; + + let ctx = SessionContext::new(); + let options = JsonReadOptions::default() + .newline_delimited(false) + 
.file_compression_type(FileCompressionType::GZIP) + .file_extension(".json.gz"); + + ctx.register_json("test_table", &path, options).await?; + let result = ctx + .sql("SELECT a, b FROM test_table ORDER BY a") + .await? + .collect() + .await?; + + assert_snapshot!(batches_to_string(&result), @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_of_structs() -> Result<()> { + let batches = query_json_array( + r#"[{"id": 1, "items": [{"name": "x", "price": 10.5}]}, {"id": 2, "items": []}]"#, + "SELECT id, items FROM test_table ORDER BY id", + ) + .await?; + + assert_eq!(1, batches.len()); + assert_eq!(2, batches[0].num_rows()); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 146c5f6f5fd0f..bd0ac36087381 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -442,14 +442,23 @@ impl<'a> AvroReadOptions<'a> { } } -/// Options that control the reading of Line-delimited JSON files (NDJson) +#[deprecated( + since = "53.0.0", + note = "Use `JsonReadOptions` instead. This alias will be removed in a future version." +)] +#[doc = "Deprecated: Use [`JsonReadOptions`] instead."] +pub type NdJsonReadOptions<'a> = JsonReadOptions<'a>; + +/// Options that control the reading of JSON files. +/// +/// Supports both newline-delimited JSON (NDJSON) and JSON array formats. /// /// Note this structure is supplied when a datasource is created and -/// can not not vary from statement to statement. For settings that +/// can not vary from statement to statement. For settings that /// can vary statement to statement see /// [`ConfigOptions`](crate::config::ConfigOptions). #[derive(Clone)] -pub struct NdJsonReadOptions<'a> { +pub struct JsonReadOptions<'a> { /// The data source schema. 
pub schema: Option<&'a Schema>, /// Max number of rows to read from JSON files for schema inference if needed. Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`. @@ -465,9 +474,25 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Whether to read as newline-delimited JSON (default: true). + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, } -impl Default for NdJsonReadOptions<'_> { +impl Default for JsonReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -477,11 +502,12 @@ impl Default for NdJsonReadOptions<'_> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], + newline_delimited: true, } } } -impl<'a> NdJsonReadOptions<'a> { +impl<'a> JsonReadOptions<'a> { /// Specify table_partition_cols for partition pruning pub fn table_partition_cols( mut self, @@ -529,6 +555,26 @@ impl<'a> NdJsonReadOptions<'a> { self.schema_infer_max_records = schema_infer_max_records; self } + + /// Set whether to read as newline-delimited JSON. 
+ /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub fn newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } #[async_trait] @@ -654,7 +700,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { } #[async_trait] -impl ReadOptions<'_> for NdJsonReadOptions<'_> { +impl ReadOptions<'_> for JsonReadOptions<'_> { fn to_listing_options( &self, config: &SessionConfig, @@ -663,7 +709,8 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { let file_format = JsonFormat::default() .with_options(table_options.json) .with_schema_infer_max_rec(self.schema_infer_max_records) - .with_file_compression_type(self.file_compression_type.to_owned()); + .with_file_compression_type(self.file_compression_type.to_owned()) + .with_newline_delimited(self.newline_delimited); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 47ce519f01289..def3c0f35f9b3 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -130,7 +130,9 @@ mod tests { use datafusion_common::test_util::batches_to_string; use datafusion_common::{Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; + use datafusion_datasource::file_sink_config::{ + FileOutputMode, FileSink, FileSinkConfig, + }; use datafusion_datasource::{ListingTableUrl, PartitionedFile}; use datafusion_datasource_parquet::{ ParquetFormat, ParquetFormatFactory, ParquetSink, @@ -815,7 +817,7 @@ 
mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); let y = x.join("\n"); assert_eq!(expected, y); @@ -841,7 +843,7 @@ mod tests { double_col: Float64\n\ date_string_col: Binary\n\ string_col: Binary\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::No, no_views).await?; let with_views = "id: Int32\n\ @@ -854,7 +856,7 @@ mod tests { double_col: Float64\n\ date_string_col: BinaryView\n\ string_col: BinaryView\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::Yes, with_views).await?; Ok(()) @@ -1547,6 +1549,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1638,6 +1641,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1728,6 +1732,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 4e33f3cad51a4..5dd11739c1f57 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -110,6 +110,7 @@ mod tests { #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::listing::table::ListingTableConfigExt; + use crate::execution::options::JsonReadOptions; 
use crate::prelude::*; use crate::{ datasource::{ @@ -808,7 +809,7 @@ mod tests { .register_json( "t", tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default() + JsonReadOptions::default() .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 86af691fd7248..f85f15a6d8c63 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -54,7 +54,15 @@ impl TableProviderFactory for ListingTableFactory { cmd: &CreateExternalTable, ) -> Result> { // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state? - let session_state = state.as_any().downcast_ref::().unwrap(); + let session_state = + state + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::internal_datafusion_err!( + "ListingTableFactory requires SessionState" + ) + })?; let file_format = session_state .get_file_format_factory(cmd.file_type.as_str()) .ok_or(config_datafusion_err!( @@ -161,9 +169,7 @@ impl TableProviderFactory for ListingTableFactory { } None => format!("*.{}", cmd.file_type.to_lowercase()), }; - table_path = table_path - .with_glob(glob.as_ref())? 
- .with_table_ref(cmd.name.clone()); + table_path = table_path.with_glob(glob.as_ref())?; } let schema = options.infer_schema(session_state, &table_path).await?; let df_schema = Arc::clone(&schema).to_dfschema()?; @@ -548,4 +554,103 @@ mod tests { "Statistics cache should not be pre-warmed when collect_statistics is disabled" ); } + + #[tokio::test] + async fn test_create_with_invalid_session() { + use async_trait::async_trait; + use datafusion_catalog::Session; + use datafusion_common::Result; + use datafusion_common::config::TableOptions; + use datafusion_execution::TaskContext; + use datafusion_execution::config::SessionConfig; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::any::Any; + use std::collections::HashMap; + use std::sync::Arc; + + // A mock Session that is NOT SessionState + #[derive(Debug)] + struct MockSession; + + #[async_trait] + impl Session for MockSession { + fn session_id(&self) -> &str { + "mock_session" + } + fn config(&self) -> &SessionConfig { + unimplemented!() + } + async fn create_physical_plan( + &self, + _logical_plan: &datafusion_expr::LogicalPlan, + ) -> Result> { + unimplemented!() + } + fn create_physical_expr( + &self, + _expr: datafusion_expr::Expr, + _df_schema: &DFSchema, + ) -> Result> { + unimplemented!() + } + fn scalar_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn window_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn runtime_env(&self) -> &Arc { + unimplemented!() + } + fn execution_props( + &self, + ) -> &datafusion_expr::execution_props::ExecutionProps { + unimplemented!() + } + fn as_any(&self) -> &dyn Any { + self + } + fn table_options(&self) -> &TableOptions { + unimplemented!() + } + fn table_options_mut(&mut self) -> &mut TableOptions { + unimplemented!() + } + fn task_ctx(&self) -> Arc { + unimplemented!() + } + } + + let factory = 
ListingTableFactory::new(); + let mock_session = MockSession; + + let name = TableReference::bare("foo"); + let cmd = CreateExternalTable::builder( + name, + "foo.csv".to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .build(); + + // This should return an error, not panic + let result = factory.create(&mock_session, &cmd).await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .strip_backtrace() + .contains("Internal error: ListingTableFactory requires SessionState") + ); + } } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1f21d6a7e603a..32b3b0799dd85 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -149,10 +149,10 @@ mod tests { &self, _logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(TestPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(TestPhysicalExprAdapter { physical_file_schema, - }) + })) } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 8de6a60258f08..b70791c7b2390 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -32,7 +32,7 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::execution::SessionState; - use crate::prelude::{CsvReadOptions, NdJsonReadOptions, SessionContext}; + use crate::prelude::{CsvReadOptions, JsonReadOptions, SessionContext}; use crate::test::partitioned_file_groups; use datafusion_common::Result; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; @@ -136,7 +136,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_options = NdJsonReadOptions::default() + let read_options = JsonReadOptions::default() .file_extension(ext.as_str()) .file_compression_type(file_compression_type.to_owned()); let frame = ctx.read_json(path, 
read_options).await.unwrap(); @@ -389,7 +389,7 @@ mod tests { let path = format!("{TEST_DATA_BASE}/1.json"); // register json file with the execution context - ctx.register_json("test", path.as_str(), NdJsonReadOptions::default()) + ctx.register_json("test", path.as_str(), JsonReadOptions::default()) .await?; // register a local file system object store for /tmp directory @@ -431,7 +431,7 @@ mod tests { } // register each partition as well as the top level dir - let json_read_option = NdJsonReadOptions::default(); + let json_read_option = JsonReadOptions::default(); ctx.register_json( "part0", &format!("{out_dir}/{part_0_name}"), @@ -511,7 +511,7 @@ mod tests { async fn read_test_data(schema_infer_max_records: usize) -> Result { let ctx = SessionContext::new(); - let options = NdJsonReadOptions { + let options = JsonReadOptions { schema_infer_max_records, ..Default::default() }; @@ -587,7 +587,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_option = NdJsonReadOptions::default() + let read_option = JsonReadOptions::default() .file_compression_type(file_compression_type) .file_extension(ext.as_str()); diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index ce2b05e6d3b61..4c6d915d5bcaa 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -38,10 +38,10 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, AsArray, Date64Array, Int8Array, Int32Array, Int64Array, StringArray, - StringViewArray, StructArray, TimestampNanosecondArray, + ArrayRef, AsArray, Date64Array, DictionaryArray, Int8Array, Int32Array, + Int64Array, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, }; - use arrow::datatypes::{DataType, Field, Fields, Schema, 
SchemaBuilder}; + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder, UInt16Type}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SchemaRef, TimeUnit}; @@ -995,6 +995,7 @@ mod tests { assert_eq!(read, 1, "Expected 1 rows to match the predicate"); assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "page_index_pages_pruned"), 1); assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); // If we filter with a value that is completely out of the range of the data // we prune at the row group level. @@ -1168,10 +1169,16 @@ mod tests { // There are 4 rows pruned in each of batch2, batch3, and // batch4 for a total of 12. batch1 had no pruning as c2 was // filled in as null - let (page_index_pruned, page_index_matched) = + let (page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 12); - assert_eq!(page_index_matched, 6); + assert_eq!(page_index_rows_pruned, 12); + assert_eq!(page_index_rows_matched, 6); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 6); + assert_eq!(page_index_pages_matched, 3); } #[tokio::test] @@ -1734,6 +1741,7 @@ mod tests { Some(3), Some(4), Some(5), + Some(6), // last page with only one row ])); let batch1 = create_batch(vec![("int", c1.clone())]); @@ -1742,7 +1750,7 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter) .with_page_index_predicate() - .round_trip(vec![batch1]) + .round_trip(vec![batch1.clone()]) .await; let metrics = rt.parquet_exec.metrics().unwrap(); @@ -1755,14 +1763,40 @@ mod tests { | 5 | +-----+ "); - let (page_index_pruned, page_index_matched) = + let 
(page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 4); - assert_eq!(page_index_matched, 2); + assert_eq!(page_index_rows_pruned, 5); + assert_eq!(page_index_rows_matched, 2); assert!( get_value(&metrics, "page_index_eval_time") > 0, "no eval time in metrics: {metrics:#?}" ); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); + + // test with a filter that matches the page with one row + let filter = col("int").eq(lit(6_i32)); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_page_index_predicate() + .round_trip(vec![batch1]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + let (page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + assert_eq!(page_index_rows_pruned, 6); + assert_eq!(page_index_rows_matched, 1); + + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); } /// Returns a string array with contents: @@ -2229,6 +2263,48 @@ mod tests { Ok(()) } + /// Tests that constant dictionary columns (where min == max in statistics) + /// are correctly handled. This reproduced a bug where the constant value + /// from statistics had type Utf8 but the schema expected Dictionary. 
+ #[tokio::test] + async fn test_constant_dictionary_column_parquet() -> Result<()> { + let tmp_dir = TempDir::new()?; + let path = tmp_dir.path().to_str().unwrap().to_string() + "/test.parquet"; + + // Write parquet with dictionary column where all values are the same + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + false, + )])); + let status: DictionaryArray = + vec!["active", "active"].into_iter().collect(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(status)])?; + let file = File::create(&path)?; + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::Page) + .build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Query the constant dictionary column + let ctx = SessionContext::new(); + ctx.register_parquet("t", &path, ParquetReadOptions::default()) + .await?; + let result = ctx.sql("SELECT status FROM t").await?.collect().await?; + + insta::assert_snapshot!(batches_to_string(&result),@r" + +--------+ + | status | + +--------+ + | active | + | active | + +--------+ + "); + Ok(()) + } + fn write_file(file: &String) { let struct_fields = Fields::from(vec![ Field::new("id", DataType::Int64, false), diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index e9d799400863d..f7df2ad7a1cd6 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. 
+use super::super::options::ReadOptions; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; +use crate::execution::options::JsonReadOptions; use datafusion_common::TableReference; use datafusion_datasource_json::source::plan_to_json; use std::sync::Arc; -use super::super::options::{NdJsonReadOptions, ReadOptions}; -use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; - impl SessionContext { /// Creates a [`DataFrame`] for reading an JSON data source. /// @@ -32,7 +32,7 @@ impl SessionContext { pub async fn read_json( &self, table_paths: P, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result { self._read_type(table_paths, options).await } @@ -43,7 +43,7 @@ impl SessionContext { &self, table_ref: impl Into, table_path: impl AsRef, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 712f4389f5852..b6c606ff467f9 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2186,7 +2186,7 @@ mod tests { // configure with same memory / disk manager let memory_pool = ctx1.runtime_env().memory_pool.clone(); - let mut reservation = MemoryConsumer::new("test").register(&memory_pool); + let reservation = MemoryConsumer::new("test").register(&memory_pool); reservation.grow(100); let disk_manager = ctx1.runtime_env().disk_manager.clone(); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 7cdbc77ae90c3..9560616c1b6da 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -878,11 +878,8 @@ impl SessionState { &self.catalog_list } - /// set the catalog list - pub(crate) fn 
register_catalog_list( - &mut self, - catalog_list: Arc, - ) { + /// Set the catalog list + pub fn register_catalog_list(&mut self, catalog_list: Arc) { self.catalog_list = catalog_list; } @@ -972,6 +969,7 @@ impl SessionState { /// be used for all values unless explicitly provided. /// /// See example on [`SessionState`] +#[derive(Clone)] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, @@ -1843,9 +1841,14 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.execution_props().query_execution_start_time, ); let simplifier = ExprSimplifier::new(simplify_context); + let schema = DFSchema::empty(); let args = args .into_iter() - .map(|arg| simplifier.simplify(arg)) + .map(|arg| { + simplifier + .coerce(arg, &schema) + .and_then(|e| simplifier.simplify(e)) + }) .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index e83934a8e281d..349eee5592abe 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -1181,8 +1180,56 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/upgrading.md", - library_user_guide_upgrading + "../../../docs/source/library-user-guide/upgrading/46.0.0.md", + library_user_guide_upgrading_46_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/47.0.0.md", + library_user_guide_upgrading_47_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.0.md", + library_user_guide_upgrading_48_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.1.md", + library_user_guide_upgrading_48_0_1 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/49.0.0.md", + library_user_guide_upgrading_49_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/50.0.0.md", + library_user_guide_upgrading_50_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/51.0.0.md", + library_user_guide_upgrading_51_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/52.0.0.md", + library_user_guide_upgrading_52_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/53.0.0.md", + library_user_guide_upgrading_53_0_0 ); #[cfg(doctest)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cc7d534776d7e..6765b7f79fdd2 100644 --- 
a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; -use crate::datasource::physical_plan::FileSinkConfig; +use crate::datasource::physical_plan::{FileOutputMode, FileSinkConfig}; use crate::datasource::{DefaultTableSource, source_as_provider}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; @@ -39,7 +39,7 @@ use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::filter::FilterExecBuilder; use crate::physical_plan::joins::utils as join_utils; use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, @@ -549,8 +549,30 @@ impl DefaultPhysicalPlanner { } }; + // Parse single_file_output option if explicitly set + let file_output_mode = match source_option_tuples + .get("single_file_output") + .map(|v| v.trim()) + { + None => FileOutputMode::Automatic, + Some("true") => FileOutputMode::SingleFile, + Some("false") => FileOutputMode::Directory, + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'single_file_output' was not recognized: \"{value}\"" + ))); + } + }; + + // Filter out sink-related options that are not format options + let format_options: HashMap = source_option_tuples + .iter() + .filter(|(k, _)| k.as_str() != "single_file_output") + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let sink_format = file_type_to_format(file_type)? 
- .create(session_state, source_option_tuples)?; + .create(session_state, &format_options)?; // Determine extension based on format extension and compression let file_extension = match sink_format.compression_type() { @@ -571,6 +593,7 @@ impl DefaultPhysicalPlanner { insert_op: InsertOp::Append, keep_partition_by_columns, file_extension, + file_output_mode, }; let ordering = input_exec.properties().output_ordering().cloned(); @@ -655,6 +678,30 @@ impl DefaultPhysicalPlanner { ); } } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Truncate, + .. + }) => { + if let Some(provider) = + target.as_any().downcast_ref::() + { + provider + .table_provider + .truncate(session_state) + .await + .map_err(|e| { + e.context(format!( + "TRUNCATE operation on table '{table_name}'" + )) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } LogicalPlan::Window(Window { window_expr, .. }) => { assert_or_internal_err!( !window_expr.is_empty(), @@ -938,8 +985,12 @@ impl DefaultPhysicalPlanner { input_schema.as_arrow(), )? { PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { - FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? - .with_batch_size(session_state.config().batch_size())? + FilterExecBuilder::new( + Arc::clone(&runtime_expr[0]), + physical_input, + ) + .with_batch_size(session_state.config().batch_size()) + .build()? } PlanAsyncExpr::Async( async_map, @@ -949,16 +1000,17 @@ impl DefaultPhysicalPlanner { async_map.async_exprs, physical_input, )?; - FilterExec::try_new( + FilterExecBuilder::new( Arc::clone(&runtime_expr[0]), Arc::new(async_exec), - )? + ) // project the output columns excluding the async functions // The async functions are always appended to the end of the schema. - .with_projection(Some( - (0..input.schema().fields().len()).collect(), + .apply_projection(Some( + (0..input.schema().fields().len()).collect::>(), ))? 
- .with_batch_size(session_state.config().batch_size())? + .with_batch_size(session_state.config().batch_size()) + .build()? } _ => { return internal_err!( @@ -1091,6 +1143,7 @@ impl DefaultPhysicalPlanner { filter, join_type, null_equality, + null_aware, schema: join_schema, .. }) => { @@ -1487,6 +1540,8 @@ impl DefaultPhysicalPlanner { } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && prefer_hash_join + && !*null_aware + // Null-aware joins must use CollectLeft { Arc::new(HashJoinExec::try_new( physical_left, @@ -1497,6 +1552,7 @@ impl DefaultPhysicalPlanner { None, PartitionMode::Auto, *null_equality, + *null_aware, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1508,6 +1564,7 @@ impl DefaultPhysicalPlanner { None, PartitionMode::CollectLeft, *null_equality, + *null_aware, )?) }; @@ -2719,7 +2776,7 @@ impl<'a> OptimizationInvariantChecker<'a> { && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) { internal_err!( - "PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", + "PhysicalOptimizer rule '{}' failed. Schema mismatch. 
Expected original schema: {}, got new schema: {}", self.rule.name(), previous_schema, plan.schema() diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 50e4a2649c923..31d9d7eb471f0 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -29,7 +29,7 @@ pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, + AvroReadOptions, CsvReadOptions, JsonReadOptions, ParquetReadOptions, }; pub use datafusion_common::Column; diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs index 84cf97710a902..c53819ffcca58 100644 --- a/datafusion/core/tests/custom_sources_cases/dml_planning.rs +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Tests for DELETE and UPDATE planning to verify filter and assignment extraction. +//! Tests for DELETE, UPDATE, and TRUNCATE planning to verify filter and assignment extraction. use std::any::Any; use std::sync::{Arc, Mutex}; @@ -24,9 +24,10 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result; -use datafusion::execution::context::SessionContext; +use datafusion::execution::context::{SessionConfig, SessionContext}; use datafusion::logical_expr::Expr; use datafusion_catalog::Session; +use datafusion_common::ScalarValue; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::empty::EmptyExec; @@ -94,7 +95,7 @@ impl TableProvider for CaptureDeleteProvider { } /// A TableProvider that captures filters and assignments passed to update(). 
-#[allow(clippy::type_complexity)] +#[expect(clippy::type_complexity)] struct CaptureUpdateProvider { schema: SchemaRef, received_filters: Arc>>>, @@ -165,6 +166,66 @@ impl TableProvider for CaptureUpdateProvider { } } +/// A TableProvider that captures whether truncate() was called. +struct CaptureTruncateProvider { + schema: SchemaRef, + truncate_called: Arc>, +} + +impl CaptureTruncateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + truncate_called: Arc::new(Mutex::new(false)), + } + } + + fn was_truncated(&self) -> bool { + *self.truncate_called.lock().unwrap() + } +} + +impl std::fmt::Debug for CaptureTruncateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureTruncateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureTruncateProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn truncate(&self, _state: &dyn Session) -> Result> { + *self.truncate_called.lock().unwrap() = true; + + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } +} + fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -269,6 +330,28 @@ async fn test_update_assignments() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_truncate_calls_provider() -> Result<()> { + let provider = Arc::new(CaptureTruncateProvider::new(test_schema())); + let config = SessionConfig::new().set( + "datafusion.optimizer.max_passes", + &ScalarValue::UInt64(Some(0)), + ); + + let ctx = SessionContext::new_with_config(config); + + 
ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("TRUNCATE TABLE t").await?.collect().await?; + + assert!( + provider.was_truncated(), + "truncate() should be called on the TableProvider" + ); + + Ok(()) +} + #[tokio::test] async fn test_unsupported_table_delete() -> Result<()> { let schema = test_schema(); @@ -295,3 +378,18 @@ async fn test_unsupported_table_update() -> Result<()> { assert!(result.is_err() || result.unwrap().collect().await.is_err()); Ok(()) } + +#[tokio::test] +async fn test_unsupported_table_truncate() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("TRUNCATE TABLE empty_t").await; + + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + + Ok(()) +} diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index 8453615c2886b..ec0b9e253d2ab 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -180,10 +180,6 @@ impl ExecutionPlan for CustomExecutionPlan { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema())); diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ca1eaa1f958ea..b54a57b033591 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -29,7 +29,7 @@ use datafusion::logical_expr::TableProviderFilterPushDown; use 
datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; @@ -149,12 +149,6 @@ impl ExecutionPlan for CustomPlan { })), ))) } - - fn statistics(&self) -> Result { - // here we could provide more accurate statistics - // but we want to test the filter pushdown not the CBOs - Ok(Statistics::new_unknown(&self.schema())) - } } #[derive(Clone, Debug)] diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 820c2a470b376..e81cd9f6b81b1 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -181,10 +181,6 @@ impl ExecutionPlan for StatisticsValidation { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { Ok(Statistics::new_unknown(&self.schema)) diff --git a/datafusion/core/tests/data/json_array.json b/datafusion/core/tests/data/json_array.json new file mode 100644 index 0000000000000..1a8716dbf4beb --- /dev/null +++ b/datafusion/core/tests/data/json_array.json @@ -0,0 +1,5 @@ +[ + {"a": 1, "b": "hello"}, + {"a": 2, "b": "world"}, + {"a": 3, "b": "test"} +] diff --git a/datafusion/core/tests/data/json_empty_array.json b/datafusion/core/tests/data/json_empty_array.json new file mode 100644 index 0000000000000..fe51488c7066f --- /dev/null +++ b/datafusion/core/tests/data/json_empty_array.json @@ -0,0 +1 @@ +[] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c09db371912b0..6c0452a99bccc 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ 
b/datafusion/core/tests/dataframe/mod.rs @@ -43,6 +43,7 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; use object_store::local::LocalFileSystem; +use rstest::rstest; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -56,9 +57,7 @@ use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ColumnarValue, Volatility}; -use datafusion::prelude::{ - CsvReadOptions, JoinType, NdJsonReadOptions, ParquetReadOptions, -}; +use datafusion::prelude::{CsvReadOptions, JoinType, ParquetReadOptions}; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, test_table_with_cache_factory, test_table_with_name, @@ -93,6 +92,7 @@ use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; use datafusion::error::Result as DataFusionResult; +use datafusion::execution::options::JsonReadOptions; use datafusion_functions_window::expr_fn::lag; // Get string representation of the plan @@ -534,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> { async fn drop_columns_with_empty_array() -> Result<()> { // build plan using Table API let t = test_table().await?; - let t2 = t.drop_columns(&[])?; + let drop_columns = vec![] as Vec<&str>; + let t2 = t.drop_columns(&drop_columns)?; let plan = t2.logical_plan().clone(); // build query using SQL @@ -549,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> { Ok(()) } +#[tokio::test] +async fn drop_columns_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", 
"c2", "c11"])?; + let mut t3 = t.join_on( + t2, + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&["another_table.c2", "another_table.c11"])?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_qualified_find_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2.clone(), + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&t2.find_qualified_columns(&["c2", "c11"])?)?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn test_find_qualified_names() -> Result<()> { + let t 
= test_table().await?; + let column_names = ["c1", "c2", "c3"]; + let columns = t.find_qualified_columns(&column_names)?; + + // Expected results for each column + let binding = TableReference::bare("aggregate_test_100"); + let expected = [ + (Some(&binding), "c1"), + (Some(&binding), "c2"), + (Some(&binding), "c3"), + ]; + + // Verify we got the expected number of results + assert_eq!( + columns.len(), + expected.len(), + "Expected {} columns, got {}", + expected.len(), + columns.len() + ); + + // Iterate over the results and check each one individually + for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() { + let (actual_table_ref, actual_field_ref) = actual; + let (expected_table_ref, expected_field_name) = expected; + + // Check table reference + assert_eq!( + actual_table_ref, expected_table_ref, + "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}" + ); + + // Check field name + assert_eq!( + actual_field_ref.name(), + *expected_field_name, + "Column {i}: expected field name '{expected_field_name}', got '{actual_field_ref}'" + ); + } + + Ok(()) +} + #[tokio::test] async fn drop_with_quotes() -> Result<()> { // define data with a column name that has a "." 
in it: @@ -594,7 +696,7 @@ async fn drop_with_periods() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?; let df_results = df.collect().await?; @@ -2793,7 +2895,7 @@ async fn write_json_with_order() -> Result<()> { ctx.register_json( "data", test_path.to_str().unwrap(), - NdJsonReadOptions::default().schema(&schema), + JsonReadOptions::default().schema(&schema), ) .await?; @@ -4699,7 +4801,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32, nullable: true });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(UInt32);N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] " ); @@ -5513,30 +5615,33 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { Ok(()) } +#[rstest] +#[case(DataType::Utf8)] +#[case(DataType::LargeUtf8)] +#[case(DataType::Utf8View)] #[tokio::test] -async fn write_partitioned_parquet_results() -> Result<()> { - // create partitioned input file and context - let tmp_dir = TempDir::new()?; - - let ctx = SessionContext::new(); - +async fn write_partitioned_parquet_results(#[case] string_type: DataType) -> Result<()> { // Create an in memory table with schema C1 and C2, both strings let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), + Field::new("c1", string_type.clone(), false), + Field::new("c2", string_type.clone(), false), ])); - let record_batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["abc", 
"def"])), - Arc::new(StringArray::from(vec!["123", "456"])), - ], - )?; + let columns = [ + Arc::new(StringArray::from(vec!["abc", "def"])) as ArrayRef, + Arc::new(StringArray::from(vec!["123", "456"])) as ArrayRef, + ] + .map(|col| arrow::compute::cast(&col, &string_type).unwrap()) + .to_vec(); + + let record_batch = RecordBatch::try_new(schema.clone(), columns)?; let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?); // Register the table in the context + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); ctx.register_table("test", mem_table)?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); @@ -5563,6 +5668,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the c2 column is gone and that c1 is abc. let results = filter_df.collect().await?; + insta::allow_duplicates! { assert_snapshot!( batches_to_string(&results), @r" @@ -5572,7 +5678,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { | abc | +-----+ " - ); + )}; // Read the entire set of parquet files let df = ctx @@ -5585,9 +5691,10 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the df has the entire set of data let results = df.collect().await?; - assert_snapshot!( - batches_to_sort_string(&results), - @r" + insta::allow_duplicates! 
{ + assert_snapshot!( + batches_to_sort_string(&results), + @r" +-----+-----+ | c1 | c2 | +-----+-----+ @@ -5595,7 +5702,8 @@ async fn write_partitioned_parquet_results() -> Result<()> { | def | 456 | +-----+-----+ " - ); + ) + }; Ok(()) } @@ -6213,7 +6321,7 @@ async fn register_non_json_file() { .register_json( "data", "tests/data/test_binary.parquet", - NdJsonReadOptions::default(), + JsonReadOptions::default(), ) .await; assert_contains!( diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs index 380a47505ac2d..9818d9d98f6b1 100644 --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::execution_plan::Boundedness; use datafusion::prelude::SessionContext; @@ -233,6 +233,7 @@ async fn agg_grouped_topk_yields( #[values(false, true)] pretend_infinite: bool, ) -> Result<(), Box> { // build session + let session_ctx = SessionContext::new(); // set up a top-k aggregation @@ -260,7 +261,7 @@ async fn agg_grouped_topk_yields( inf.clone(), inf.schema(), )? 
- .with_limit(Some(100)), + .with_limit_options(Some(LimitOptions::new(100))), ); query_yields(aggr, session_ctx.task_ctx()).await @@ -606,6 +607,7 @@ async fn join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -655,6 +657,7 @@ async fn join_agg_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); // Project only one column (“value” from the left side) because we just want to sum that @@ -720,6 +723,7 @@ async fn hash_join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -751,9 +755,10 @@ async fn hash_join_without_repartition_and_no_agg( /* filter */ None, &JoinType::Inner, /* output64 */ None, - // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. + // Using CollectLeft is fine—just avoid RepartitionExec's partitioned channels. PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -762,7 +767,7 @@ async fn hash_join_without_repartition_and_no_agg( #[derive(Debug)] enum Yielded { ReadyOrPending, - Err(#[allow(dead_code)] DataFusionError), + Err(#[expect(dead_code)] DataFusionError), Timeout, } diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 36cc769417dbc..3d99cc72fa590 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -94,7 +94,6 @@ mod unix_test { /// This function creates a writing task for the FIFO file. To verify /// incremental processing, it waits for a signal to continue writing after /// a certain number of lines are written. 
- #[allow(clippy::disallowed_methods)] fn create_writing_task( file_path: PathBuf, header: String, @@ -105,6 +104,7 @@ mod unix_test { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); // Spawn a new task to write to the FIFO file + #[expect(clippy::disallowed_methods)] tokio::spawn(async move { let mut file = tokio::fs::OpenOptions::new() .write(true) @@ -357,7 +357,7 @@ mod unix_test { (sink_fifo_path.clone(), sink_fifo_path.display()); // Spawn a new thread to read sink EXTERNAL TABLE. - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index bf71053d6c852..fe31098622c58 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -214,7 +214,7 @@ impl GeneratedSessionContextBuilder { /// The generated params for [`SessionContext`] #[derive(Debug)] -#[allow(dead_code)] +#[expect(dead_code)] pub struct SessionContextParams { batch_size: usize, target_partitions: usize, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs index 0d04e98536f2a..7bb6177c31010 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -182,13 +182,13 @@ impl QueryBuilder { /// Add max columns num in group by(default: 3), for example if it is set to 1, /// the generated sql will group by at most 1 column - #[allow(dead_code)] + #[expect(dead_code)] pub fn 
with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { self.max_group_by_columns = max_group_by_columns; self } - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { self.min_group_by_columns = min_group_by_columns; self @@ -202,7 +202,7 @@ impl QueryBuilder { } /// Add if also test the no grouping aggregation case(default: true) - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { self.no_grouping = no_grouping; self diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index ce422494db101..669b98e39fec1 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -849,6 +849,7 @@ impl JoinFuzzTestCase { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) @@ -1086,7 +1087,7 @@ impl JoinFuzzTestCase { /// Files can be of different sizes /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` /// for test debugging purposes - #[allow(dead_code)] + #[expect(dead_code)] async fn load_partitioned_batches_from_parquet( dir: &str, ) -> std::io::Result> { diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index c424a314270c6..8f3b8ea05324c 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -301,7 +301,7 @@ mod sp_repartition_fuzz_tests { let mut handles = Vec::new(); for seed in seed_start..seed_end { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests let job = tokio::spawn(run_sort_preserving_repartition_test( 
make_staggered_batches::(n_row, n_distinct, seed as u64), is_first_roundrobin, diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 48f0103113cf6..9fd60cd1f06f3 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -73,7 +73,7 @@ mod config_field { #[test] fn test_macro() { #[derive(Debug)] - #[allow(dead_code)] + #[expect(dead_code)] struct E; impl std::fmt::Display for E { @@ -84,7 +84,7 @@ mod config_field { impl std::error::Error for E {} - #[allow(dead_code)] + #[expect(dead_code)] #[derive(Default)] struct S; diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs index 515422ed750ef..aee37fda1670d 100644 --- a/datafusion/core/tests/parquet/expr_adapter.rs +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -63,15 +63,15 @@ impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(CustomPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(CustomPhysicalExprAdapter { logical_file_schema: Arc::clone(&logical_file_schema), physical_file_schema: Arc::clone(&physical_file_schema), inner: Arc::new(DefaultPhysicalExprAdapter::new( logical_file_schema, physical_file_schema, )), - }) + })) } } diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index e3a191ee9ade2..1eb8103d3e4d4 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -220,7 +220,6 @@ async fn single_file() { } #[tokio::test] -#[allow(dead_code)] async fn single_file_small_data_pages() { let batches = read_parquet_test_data( "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", @@ -644,6 +643,22 @@ async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { .await } +#[tokio::test] 
+async fn predicate_cache_stats_issue_19561() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // force to get multiple batches to trigger repeated metric compound bug + config.options_mut().execution.batch_size = 1; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + #[tokio::test] async fn predicate_cache_pushdown_default_selections_only() -> datafusion_common::Result<()> { diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 44c9a2393e3d8..5a05718936509 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -30,6 +30,7 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use arrow_schema::SchemaRef; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{TableProvider, provider_as_source}, @@ -110,6 +111,26 @@ struct ContextWithParquet { ctx: SessionContext, } +struct PruningMetric { + total_pruned: usize, + total_matched: usize, + total_fully_matched: usize, +} + +impl PruningMetric { + pub fn total_pruned(&self) -> usize { + self.total_pruned + } + + pub fn total_matched(&self) -> usize { + self.total_matched + } + + pub fn total_fully_matched(&self) -> usize { + self.total_fully_matched + } +} + /// The output of running one of the test cases struct TestOutput { /// The input query SQL @@ -127,8 +148,8 @@ struct TestOutput { impl TestOutput { /// retrieve the value of the named metric, if any fn metric_value(&self, metric_name: &str) -> Option { - if let Some((pruned, _matched)) = self.pruning_metric(metric_name) { - return Some(pruned); + if let Some(pm) = self.pruning_metric(metric_name) { + return Some(pm.total_pruned()); } self.parquet_metrics @@ -141,9 
+162,10 @@ impl TestOutput { }) } - fn pruning_metric(&self, metric_name: &str) -> Option<(usize, usize)> { + fn pruning_metric(&self, metric_name: &str) -> Option { let mut total_pruned = 0; let mut total_matched = 0; + let mut total_fully_matched = 0; let mut found = false; for metric in self.parquet_metrics.iter() { @@ -155,12 +177,18 @@ impl TestOutput { { total_pruned += pruning_metrics.pruned(); total_matched += pruning_metrics.matched(); + total_fully_matched += pruning_metrics.fully_matched(); + found = true; } } if found { - Some((total_pruned, total_matched)) + Some(PruningMetric { + total_pruned, + total_matched, + total_fully_matched, + }) } else { None } @@ -172,27 +200,33 @@ impl TestOutput { } /// The number of row_groups pruned / matched by bloom filter - fn row_groups_bloom_filter(&self) -> Option<(usize, usize)> { + fn row_groups_bloom_filter(&self) -> Option { self.pruning_metric("row_groups_pruned_bloom_filter") } /// The number of row_groups matched by statistics fn row_groups_matched_statistics(&self) -> Option { self.pruning_metric("row_groups_pruned_statistics") - .map(|(_pruned, matched)| matched) + .map(|pm| pm.total_matched()) + } + + /// The number of row_groups fully matched by statistics + fn row_groups_fully_matched_statistics(&self) -> Option { + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_fully_matched()) } /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { self.pruning_metric("row_groups_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// Metric `files_ranges_pruned_statistics` tracks both pruned and matched count, /// for testing purpose, here it only aggregate the `pruned` count. 
fn files_ranges_pruned_statistics(&self) -> Option { self.pruning_metric("files_ranges_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// The number of row_groups matched by bloom filter or statistics @@ -201,14 +235,13 @@ impl TestOutput { /// filter: 7 total -> 3 matched, this function returns 3 for the final matched /// count. fn row_groups_matched(&self) -> Option { - self.row_groups_bloom_filter() - .map(|(_pruned, matched)| matched) + self.row_groups_bloom_filter().map(|pm| pm.total_matched()) } /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { self.row_groups_bloom_filter() - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) .zip(self.row_groups_pruned_statistics()) .map(|(a, b)| a + b) } @@ -216,7 +249,13 @@ impl TestOutput { /// The number of row pages pruned fn row_pages_pruned(&self) -> Option { self.pruning_metric("page_index_rows_pruned") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) + } + + /// The number of row groups pruned by limit pruning + fn limit_pruned_row_groups(&self) -> Option { + self.pruning_metric("limit_pruned_row_groups") + .map(|pm| pm.total_pruned()) } fn description(&self) -> String { @@ -232,20 +271,41 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario, unit: Unit) -> Self { - Self::with_config(scenario, unit, SessionConfig::new()).await + Self::with_config(scenario, unit, SessionConfig::new(), None, None).await + } + + /// Set custom schema and batches for the test + pub async fn with_custom_data( + scenario: Scenario, + unit: Unit, + schema: Arc, + batches: Vec, + ) -> Self { + Self::with_config( + scenario, + unit, + SessionConfig::new(), + Some(schema), + Some(batches), + ) + .await } async fn with_config( scenario: Scenario, unit: Unit, mut config: SessionConfig, + custom_schema: Option, + custom_batches: Option>, ) -> Self { // Use a single partition for 
deterministic results no matter how many CPUs the host has config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); - make_test_file_rg(scenario, row_per_group).await + config.options_mut().execution.parquet.pushdown_filters = true; + make_test_file_rg(scenario, row_per_group, custom_schema, custom_batches) + .await } Unit::Page(row_per_page) => { config = config.with_parquet_page_index_pruning(true); @@ -516,9 +576,9 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); - let v32: Vec = (start as _..end as _).collect(); - let v64: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as u16).collect(); + let v32: Vec = (start as u32..end as u32).collect(); + let v64: Vec = (start as u64..end as u64).collect(); RecordBatch::try_new( schema, vec![ @@ -1075,7 +1135,12 @@ fn create_data_batch(scenario: Scenario) -> Vec { } /// Create a test parquet file with various data types -async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile { +async fn make_test_file_rg( + scenario: Scenario, + row_per_group: usize, + custom_schema: Option, + custom_batches: Option>, +) -> NamedTempFile { let mut output_file = tempfile::Builder::new() .prefix("parquet_pruning") .suffix(".parquet") @@ -1088,8 +1153,14 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .set_statistics_enabled(EnabledStatistics::Page) .build(); - let batches = create_data_batch(scenario); - let schema = batches[0].schema(); + let (batches, schema) = + if let (Some(schema), Some(batches)) = (custom_schema, custom_batches) { + (batches, schema) + } else { + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + (batches, schema) + }; let mut writer = 
ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 7eb39bfe78305..6d49e0bcc676e 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -20,7 +20,8 @@ use std::sync::Arc; use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; -use arrow::array::RecordBatch; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::PartitionedFile; @@ -30,7 +31,7 @@ use datafusion::datasource::source::DataSourceExec; use datafusion::execution::context::SessionState; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::metrics::MetricValue; -use datafusion::prelude::SessionContext; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{Expr, col, lit}; @@ -40,6 +41,8 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use futures::StreamExt; use object_store::ObjectMeta; use object_store::path::Path; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; async fn get_parquet_exec( state: &SessionState, @@ -961,3 +964,56 @@ fn cast_count_metric(metric: MetricValue) -> Option { _ => None, } } + +#[tokio::test] +async fn test_parquet_opener_without_page_index() { + // Defines a simple schema and batch + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + // Create a temp file + let file = tempfile::Builder::new() + 
.suffix(".parquet") + .tempfile() + .unwrap(); + let path = file.path().to_str().unwrap().to_string(); + + // Write parquet WITHOUT page index + // The default WriterProperties does not write page index, but we set it explicitly + // to be robust against future changes in defaults as requested by reviewers. + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::None) + .build(); + + let file_fs = std::fs::File::create(&path).unwrap(); + let mut writer = ArrowWriter::try_new(file_fs, batch.schema(), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + // Setup SessionContext with PageIndex enabled + // This triggers the ParquetOpener to try and load page index if available + let config = SessionConfig::new().with_parquet_page_index_pruning(true); + + let ctx = SessionContext::new_with_config(config); + + // Register the table + ctx.register_parquet("t", &path, Default::default()) + .await + .unwrap(); + + // Query the table + // If the bug exists, this might fail because Opener tries to load PageIndex forcefully + let df = ctx.sql("SELECT * FROM t").await.unwrap(); + let batches = df + .collect() + .await + .expect("Failed to read parquet file without page index"); + + // We expect this to succeed, but currently it might fail + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); +} diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 0411298055f26..445ae7e97f228 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -18,8 +18,12 @@ //! This file contains an end to end test of parquet pruning. It writes //! data into a parquet file and then verifies row groups are pruned as //! expected. 
+use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::SessionConfig; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use itertools::Itertools; use crate::parquet::Unit::RowGroup; @@ -30,10 +34,12 @@ struct RowGroupPruningTest { query: String, expected_errors: Option, expected_row_group_matched_by_statistics: Option, + expected_row_group_fully_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, + expected_limit_pruned_row_groups: Option, expected_rows: usize, } impl RowGroupPruningTest { @@ -45,9 +51,11 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_fully_matched_by_statistics: None, expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, + expected_limit_pruned_row_groups: None, expected_rows: 0, } } @@ -76,6 +84,15 @@ impl RowGroupPruningTest { self } + // Set the expected fully matched row groups by statistics + fn with_fully_matched_by_stats( + mut self, + fully_matched_by_stats: Option, + ) -> Self { + self.expected_row_group_fully_matched_by_statistics = fully_matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; @@ -99,6 +116,11 @@ impl RowGroupPruningTest { self } + fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option) -> Self { + self.expected_limit_pruned_row_groups = pruned_by_limit; + self + } + /// Set the number of expected rows from the output of this test fn 
with_expected_rows(mut self, rows: usize) -> Self { self.expected_rows = rows; @@ -135,15 +157,74 @@ impl RowGroupPruningTest { ); let bloom_filter_metrics = output.row_groups_bloom_filter(); assert_eq!( - bloom_filter_metrics.map(|(_pruned, matched)| matched), + bloom_filter_metrics.as_ref().map(|pm| pm.total_matched()), self.expected_row_group_matched_by_bloom_filter, "mismatched row_groups_matched_bloom_filter", ); assert_eq!( - bloom_filter_metrics.map(|(pruned, _matched)| pruned), + bloom_filter_metrics.map(|pm| pm.total_pruned()), self.expected_row_group_pruned_by_bloom_filter, "mismatched row_groups_pruned_bloom_filter", ); + + assert_eq!( + output.result_rows, + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, + output.description(), + ); + } + + // Execute the test with the current configuration + async fn test_row_group_prune_with_custom_data( + self, + schema: Arc, + batches: Vec, + max_row_per_group: usize, + ) { + let output = ContextWithParquet::with_custom_data( + self.scenario, + RowGroup(max_row_per_group), + schema, + batches, + ) + .await + .query(&self.query) + .await; + + println!("{}", output.description()); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation error" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); + assert_eq!( + output.row_groups_fully_matched_statistics(), + self.expected_row_group_fully_matched_by_statistics, + "mismatched row_groups_fully_matched_statistics", + ); + assert_eq!( + output.row_groups_pruned_statistics(), + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); + assert_eq!( + 
output.limit_pruned_row_groups(), + self.expected_limit_pruned_row_groups, + "mismatched limit_pruned_row_groups", + ); assert_eq!( output.result_rows, self.expected_rows, @@ -289,11 +370,16 @@ async fn prune_disabled() { let expected_rows = 10; let config = SessionConfig::new().with_parquet_pruning(false); - let output = - ContextWithParquet::with_config(Scenario::Timestamps, RowGroup(5), config) - .await - .query(query) - .await; + let output = ContextWithParquet::with_config( + Scenario::Timestamps, + RowGroup(5), + config, + None, + None, + ) + .await + .query(query) + .await; println!("{}", output.description()); // This should not prune any @@ -1636,3 +1722,240 @@ async fn test_bloom_filter_decimal_dict() { .test_row_group_prune() .await; } + +// Helper function to create a batch with a single Int32 column. +fn make_i32_batch( + name: &str, + values: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)])); + let array: ArrayRef = Arc::new(Int32Array::from(values)); + RecordBatch::try_new(schema, vec![array]).map_err(DataFusionError::from) +} + +// Helper function to create a batch with two Int32 columns +fn make_two_col_i32_batch( + name_a: &str, + name_b: &str, + values_a: Vec, + values_b: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new(name_a, DataType::Int32, false), + Field::new(name_b, DataType::Int32, false), + ])); + let array_a: ArrayRef = Arc::new(Int32Array::from(values_a)); + let array_b: ArrayRef = Arc::new(Int32Array::from(values_b)); + RecordBatch::try_new(schema, vec![array_a, array_b]).map_err(DataFusionError::from) +} + +#[tokio::test] +async fn test_limit_pruning_basic() -> datafusion_common::error::Result<()> { + // Scenario: Simple integer column, multiple row groups + // Query: SELECT c1 FROM t WHERE c1 = 0 LIMIT 2 + // We expect 2 rows in total. 
+ + // Row Group 0: c1 = [0, -2] -> Partially matched, 1 row + // Row Group 1: c1 = [1, 2] -> Fully matched, 2 rows + // Row Group 2: c1 = [3, 4] -> Fully matched, 2 rows + // Row Group 3: c1 = [5, 6] -> Fully matched, 2 rows + // Row Group 4: c1 = [-1, -2] -> Not matched + + // If limit = 2, and RG1 is fully matched and has 2 rows, we should + // only scan RG1 and prune other row groups + // RG4 is pruned by statistics. RG2 and RG3 are pruned by limit. + // So 2 row groups are effectively pruned due to limit pruning. + + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let query = "SELECT c1 FROM t WHERE c1 >= 0 LIMIT 2"; + + let batches = vec![ + make_i32_batch("c1", vec![0, -2])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![-1, -2])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) // Assuming Scenario::Int can handle this data + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(2) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_limit_pruned_row_groups(Some(3)) + .test_row_group_prune_with_custom_data(schema, batches, 2) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_complex_filter() -> datafusion_common::error::Result<()> { + // Test Case 1: Complex filter with two columns (a = 1 AND b > 1 AND b < 4) + // Row Group 0: a=[1,1,1], b=[0,2,3] -> Partially matched, 2 rows match (b=2,3) + // Row Group 1: a=[1,1,1], b=[2,2,2] -> Fully matched, 3 rows + // Row Group 2: a=[1,1,1], b=[2,3,3] -> Fully matched, 3 rows + // Row Group 3: a=[1,1,1], b=[2,2,3] -> Fully matched, 3 rows + // Row Group 4: a=[2,2,2], b=[2,2,2] -> Not matched (a != 1) + // Row Group 5: a=[1,1,1], b=[5,6,7] -> Not matched (b >= 4) + + // With LIMIT 5, we need RG1 (3 rows) + RG2 (2 rows from 3) = 5 rows + // 
RG4 and RG5 should be pruned by statistics + // RG3 should be pruned by limit + // RG0 is partially matched, so it depends on the order + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let query = "SELECT a, b FROM t WHERE a = 1 AND b > 1 AND b < 4 LIMIT 5"; + + let batches = vec![ + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![0, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 3, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![2, 2, 2], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![5, 6, 7])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(5) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 are matched + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(2)) // RG4,5 are pruned + .with_limit_pruned_row_groups(Some(2)) // RG0, RG3 is pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_multiple_fully_matched() +-> datafusion_common::error::Result<()> { + // Test Case 2: Limit requires multiple fully matched row groups + // Row Group 0: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 1: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 2: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 3: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 4: a=[1,2,3,4] -> Not matched + + // With LIMIT 8, we need RG0 (4 rows) + RG1 (4 rows) 8 rows + // RG2,3 should be pruned by limit + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 5 LIMIT 8"; + + let 
batches = vec![ + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![1, 2, 3, 4])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(8) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(2)) // RG2,3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_no_fully_matched() -> datafusion_common::error::Result<()> { + // Test Case 3: No fully matched row groups - all are partially matched + // Row Group 0: a=[1,2,3] -> Partially matched, 1 row (a=2) + // Row Group 1: a=[2,3,4] -> Partially matched, 1 row (a=2) + // Row Group 2: a=[2,5,6] -> Partially matched, 1 row (a=2) + // Row Group 3: a=[2,7,8] -> Partially matched, 1 row (a=2) + // Row Group 4: a=[9,10,11] -> Not matched + + // With LIMIT 3, we need to scan RG0,1,2 to get 3 matching rows + // Cannot prune much by limit since all matching RGs are partial + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 2 LIMIT 3"; + + let batches = vec![ + make_i32_batch("a", vec![1, 2, 3])?, + make_i32_batch("a", vec![2, 3, 4])?, + make_i32_batch("a", vec![2, 5, 6])?, + make_i32_batch("a", vec![2, 7, 8])?, + make_i32_batch("a", vec![9, 10, 11])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(3) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(0)) + 
.with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // RG3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error::Result<()> +{ + // Test Case 4: Limit exceeds all fully matched rows, need partially matched + // Row Group 0: a=[10,11,12,12] -> Partially matched, 1 row (a=10) + // Row Group 1: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 2: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 3: a=[10,13,14,11] -> Partially matched, 1 row (a=10) + // Row Group 4: a=[20,21,22,22] -> Not matched + + // With LIMIT 10, we need RG1 (4) + RG2 (4) = 8 from fully matched + // Still need 2 more, so we need to scan partially matched RG0 and RG3 + // All matching row groups should be scanned, only RG4 pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 10 LIMIT 10"; + + let batches = vec![ + make_i32_batch("a", vec![10, 11, 12, 12])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 13, 14, 11])?, + make_i32_batch("a", vec![20, 21, 22, 22])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(10) // Total: 1 + 4 + 4 + 1 = 10 + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // No limit pruning since we need all RGs + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 
2fdfece2a86e7..9e63c341c92d9 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -37,7 +37,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion_physical_plan::displayable; use datafusion_physical_plan::repartition::RepartitionExec; @@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { schema, ) .unwrap() - .with_limit(Some(5)), + .with_limit_options(Some(LimitOptions::new(5))), ); let plan: Arc = final_agg; // should combine the Partial/Final AggregateExecs to a Single AggregateExec diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 94ae82a9ad755..30edd7196606e 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -67,8 +67,7 @@ use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, - displayable, + DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, displayable, }; use insta::Settings; @@ -210,10 +209,6 @@ impl ExecutionPlan for SortRequiredExec { ) -> Result { unreachable!(); } - - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } } fn parquet_exec() -> Arc { diff --git 
a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 4b74aebdf5deb..6349ff1cd109f 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -29,11 +29,11 @@ use crate::physical_optimizer::test_utils::{ spr_repartition_exec, stream_exec_ordered, union_exec, }; -use arrow::compute::SortOptions; +use arrow::compute::{SortOptions}; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, TableReference}; +use datafusion_common::{create_array, Result, TableReference}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_expr_common::operator::Operator; @@ -58,7 +58,7 @@ use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion::prelude::*; -use arrow::array::{Int32Array, RecordBatch}; +use arrow::array::{record_batch, ArrayRef, Int32Array, RecordBatch}; use arrow::datatypes::{Field}; use arrow_schema::Schema; use datafusion_execution::TaskContext; @@ -2805,3 +2805,47 @@ async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_sort_with_streaming_table() -> Result<()> { + let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [1, 2, 3]))?; + + let ctx = SessionContext::new(); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), + ]; + let schema = 
batch.schema(); + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch], + }) as _; + let provider = StreamingTable::try_new(schema.clone(), vec![batches])? + .with_sort_order(sort_order); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT a FROM test_table GROUP BY a ORDER BY a"; + let results = ctx.sql(sql).await?.collect().await?; + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_columns(), 1); + let expected = create_array!(Int32, vec![1, 2, 3]) as ArrayRef; + assert_eq!(results[0].column(0), &expected); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs similarity index 83% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs rename to datafusion/core/tests/physical_optimizer/filter_pushdown.rs index f2d6607e3ca1b..99db81d34d8fa 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -58,20 +58,19 @@ use datafusion_physical_plan::{ aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, coalesce_partitions::CoalescePartitionsExec, collect, - filter::FilterExec, + filter::{FilterExec, FilterExecBuilder}, + projection::ProjectionExec, repartition::RepartitionExec, sorts::sort::SortExec, }; +use super::pushdown_utils::{ + OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test, +}; use datafusion_physical_plan::union::UnionExec; use futures::StreamExt; use object_store::{ObjectStore, memory::InMemory}; use regex::Regex; -use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test}; - -use crate::physical_optimizer::filter_pushdown::util::TestSource; - -mod util; #[test] fn test_pushdown_into_scan() { @@ -233,6 +232,7 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { None, PartitionMode::Partitioned, 
datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -354,6 +354,7 @@ async fn test_static_filter_pushdown_through_hash_join() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -400,7 +401,8 @@ async fn test_static_filter_pushdown_through_hash_join() { " ); - // Test left join - filters should NOT be pushed down + // Test left join: filter on preserved (build) side is pushed down, + // filter on non-preserved (probe) side is NOT pushed down. let join = Arc::new( HashJoinExec::try_new( TestScanBuilder::new(Arc::clone(&build_side_schema)) @@ -418,30 +420,36 @@ async fn test_static_filter_pushdown_through_hash_join() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); let join_schema = join.schema(); - let filter = col_lit_predicate("a", "aa", &join_schema); - let plan = - Arc::new(FilterExec::try_new(filter, join).unwrap()) as Arc; + // Filter on build side column (preserved): should be pushed down + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column (not preserved): should NOT be pushed down + let right_filter = col_lit_predicate("e", "ba", &join_schema); + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()) + as Arc; - // Test that filters are NOT pushed down for left join insta::assert_snapshot!( OptimizationTest::new(plan, FilterPushdown::new(), true), @r" OptimizationTest: input: - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + - FilterExec: e@4 = ba + - FilterExec: a@0 
= aa + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true output: Ok: - - FilterExec: a@0 = aa + - FilterExec: e@4 = ba - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true " ); @@ -477,9 +485,10 @@ fn test_filter_with_projection() { let projection = vec![1, 0]; let predicate = col_lit_predicate("a", "foo", &schema()); let plan = Arc::new( - FilterExec::try_new(predicate, Arc::clone(&scan)) + FilterExecBuilder::new(predicate, Arc::clone(&scan)) + .apply_projection(Some(projection)) .unwrap() - .with_projection(Some(projection)) + .build() .unwrap(), ); @@ -502,9 +511,10 @@ fn test_filter_with_projection() { let projection = vec![1]; let predicate = col_lit_predicate("a", "foo", &schema()); let plan = Arc::new( - FilterExec::try_new(predicate, scan) + FilterExecBuilder::new(predicate, scan) + .apply_projection(Some(projection)) .unwrap() - .with_projection(Some(projection)) + .build() .unwrap(), ); insta::assert_snapshot!( @@ -561,9 +571,9 @@ fn test_pushdown_through_aggregates_on_grouping_columns() { let scan = TestScanBuilder::new(schema()).with_support(true).build(); let filter = Arc::new( - FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), scan) - .unwrap() + FilterExecBuilder::new(col_lit_predicate("a", "foo", &schema()), scan) .with_batch_size(10) + .build() .unwrap(), ); @@ -593,9 +603,9 @@ fn 
test_pushdown_through_aggregates_on_grouping_columns() { let predicate = col_lit_predicate("b", "bar", &schema()); let plan = Arc::new( - FilterExec::try_new(predicate, aggregate) - .unwrap() + FilterExecBuilder::new(predicate, aggregate) .with_batch_size(100) + .build() .unwrap(), ); @@ -981,6 +991,7 @@ async fn test_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -1170,6 +1181,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1363,6 +1375,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1531,6 +1544,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1550,6 +1564,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -1665,6 +1680,7 @@ async fn test_hashjoin_parent_filter_pushdown() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1712,6 +1728,218 @@ async fn test_hashjoin_parent_filter_pushdown() { ); } +#[test] +fn test_hashjoin_parent_filter_pushdown_same_column_names() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("build_val", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .build(); + + let probe_side_schema = 
Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("probe_val", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("id", &build_side_schema).unwrap(), + col("id", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let build_id_filter = col_lit_predicate("id", "aa", &join_schema); + let probe_val_filter = col_lit_predicate("probe_val", "x", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(build_id_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(probe_val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: probe_val@3 = x + - FilterExec: id@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true, predicate=id@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true, predicate=probe_val@1 = x + " + ); +} + +#[test] +fn test_hashjoin_parent_filter_pushdown_mark_join() { + use datafusion_common::JoinType; + use 
datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("val", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("id", &left_schema).unwrap(), + col("id", &right_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftMark, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let left_filter = col_lit_predicate("val", "x", &join_schema); + let mark_filter = col_lit_predicate("mark", true, &join_schema); + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(mark_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: mark@2 = true + - FilterExec: val@1 = x + - HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: mark@2 = true + - HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true, predicate=val@1 = x + - DataSourceExec: file_groups={1 group: 
[[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + " + ); +} + +/// Test that filters on join key columns are pushed to both sides of semi/anti joins. +/// For LeftSemi/LeftAnti, the output only contains left columns, but filters on +/// join key columns can also be pushed to the right (non-preserved) side because +/// the equijoin condition guarantees the key values match. +#[test] +fn test_hashjoin_parent_filter_pushdown_semi_anti_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("v", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("w", DataType::Utf8, false), + ])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("k", &left_schema).unwrap(), + col("k", &right_schema).unwrap(), + )]; + + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftSemi, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + // Filter on join key column: k = 'x' — should be pushed to BOTH sides + let key_filter = col_lit_predicate("k", "x", &join_schema); + // Filter on non-key column: v = 'y' — should only be pushed to the left side + let val_filter = col_lit_predicate("v", "y", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(key_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + 
OptimizationTest: + input: + - FilterExec: v@1 = y + - FilterExec: k@0 = x + - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true, predicate=k@0 = x AND v@1 = y + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true, predicate=k@0 = x + " + ); +} + /// Integration test for dynamic filter pushdown with TopK. /// We use an integration test because there are complex interactions in the optimizer rules /// that the unit tests applying a single optimizer rule do not cover. 
@@ -1798,6 +2026,67 @@ fn test_filter_pushdown_through_union() { ); } +#[test] +fn test_filter_pushdown_through_union_mixed_support() { + // Test case where one child supports filter pushdown and one doesn't + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +#[test] +fn test_filter_pushdown_through_union_does_not_support() { + // Test case where one child supports filter pushdown and one doesn't + let scan1 = TestScanBuilder::new(schema()).with_support(false).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: 
[[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + /// Schema: /// a: String /// b: String @@ -1815,6 +2104,234 @@ fn schema() -> SchemaRef { Arc::clone(&TEST_SCHEMA) } +struct ProjectionDynFilterTestCase { + schema: SchemaRef, + batches: Vec, + projection: Vec<(Arc, String)>, + sort_expr: PhysicalSortExpr, + expected_plans: Vec, +} + +async fn run_projection_dyn_filter_case(case: ProjectionDynFilterTestCase) { + let ProjectionDynFilterTestCase { + schema, + batches, + projection, + sort_expr, + expected_plans, + } = case; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(true) + .with_batches(batches) + .build(); + + let projection_exec = Arc::new(ProjectionExec::try_new(projection, scan).unwrap()); + + let sort = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), projection_exec) + .with_fetch(Some(2)), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized_plan = FilterPushdown::new_post_optimization() + .optimize(Arc::clone(&sort), &config) + .unwrap(); + + pretty_assertions::assert_eq!( + format_plan_for_test(&optimized_plan).trim(), + expected_plans[0].trim() + ); + + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), 
+ ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = optimized_plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + for (idx, expected_plan) in expected_plans.iter().enumerate().skip(1) { + stream.next().await.unwrap().unwrap(); + let formatted_plan = format_plan_for_test(&optimized_plan); + pretty_assertions::assert_eq!( + formatted_plan.trim(), + expected_plan.trim(), + "Mismatch at iteration {}", + idx + ); + } +} + +#[tokio::test] +async fn test_topk_with_projection_transformation_on_dyn_filter() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let simple_abc = vec![ + record_batch!( + ("a", Int32, [1, 2, 3]), + ("b", Utf8, ["x", "y", "z"]), + ("c", Float64, [1.0, 2.0, 3.0]) + ) + .unwrap(), + ]; + + // Case 1: Reordering [b, a] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("b", &schema).unwrap(), "b".to_string()), + (col("a", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 1)), + SortOptions::default(), + ), + expected_plans: vec![ +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false], filter=[a@1 IS NULL OR a@1 < 2] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string()] + }) + .await; + + // Case 2: Pruning [a] + 
run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![(col("a", &schema).unwrap(), "a".to_string())], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 3: Identity [a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "a".to_string()), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 
]"#.to_string(), + ], + }) + .await; + + // Case 4: Expressions [a + 1, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a_plus_1".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a_plus_1", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false], filter=[a_plus_1@0 IS NULL OR a_plus_1@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; + + // Case 5: [a as b, b as a] (swapped columns) + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "b".to_string()), + (col("b", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("b", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ 
empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false], filter=[b@0 IS NULL OR b@0 < 2] + - ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 6: Confusing expr [a + 1 as a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; +} + /// Returns a predicate that is a binary expression col = lit fn col_lit_predicate( column_name: &str, @@ -2773,6 +3290,7 @@ async fn test_hashjoin_dynamic_filter_all_partitions_empty() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -2901,6 +3419,7 @@ async fn test_hashjoin_dynamic_filter_with_nulls() { None, 
PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -3051,6 +3570,7 @@ async fn test_hashjoin_hash_table_pushdown_partitioned() { None, PartitionMode::Partitioned, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -3201,6 +3721,7 @@ async fn test_hashjoin_hash_table_pushdown_collect_left() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -3335,6 +3856,7 @@ async fn test_hashjoin_hash_table_pushdown_integer_keys() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -3443,6 +3965,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { None, PartitionMode::CollectLeft, datafusion_common::NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) as Arc; @@ -3476,3 +3999,90 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() { ); } } + +/// Regression test for https://github.com/apache/datafusion/issues/20109 +#[tokio::test] +async fn test_filter_with_projection_pushdown() { + use arrow::array::{Int64Array, RecordBatch, StringArray}; + use datafusion_physical_plan::collect; + use datafusion_physical_plan::filter::FilterExecBuilder; + + // Create schema: [time, event, size] + let schema = Arc::new(Schema::new(vec![ + Field::new("time", DataType::Int64, false), + Field::new("event", DataType::Utf8, false), + Field::new("size", DataType::Int64, false), + ])); + + // Create sample data + let timestamps = vec![100i64, 200, 300, 400, 500]; + let events = vec!["Ingestion", "Ingestion", "Query", "Ingestion", "Query"]; + let sizes = vec![10i64, 20, 30, 40, 50]; + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(timestamps)), + Arc::new(StringArray::from(events)), + Arc::new(Int64Array::from(sizes)), + ], + ) + .unwrap(); + + // Create data source + let memory_exec = 
datafusion_datasource::memory::MemorySourceConfig::try_new_exec( + &[vec![batch]], + schema.clone(), + None, + ) + .unwrap(); + + // First FilterExec: time < 350 with projection=[event@1, size@2] + let time_col = col("time", &memory_exec.schema()).unwrap(); + let time_filter = Arc::new(BinaryExpr::new( + time_col, + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int64(Some(350)))), + )); + let filter1 = Arc::new( + FilterExecBuilder::new(time_filter, memory_exec) + .apply_projection(Some(vec![1, 2])) + .unwrap() + .build() + .unwrap(), + ); + + // Second FilterExec: event = 'Ingestion' with projection=[size@1] + let event_col = col("event", &filter1.schema()).unwrap(); + let event_filter = Arc::new(BinaryExpr::new( + event_col, + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "Ingestion".to_string(), + )))), + )); + let filter2 = Arc::new( + FilterExecBuilder::new(event_filter, filter1) + .apply_projection(Some(vec![1])) + .unwrap() + .build() + .unwrap(), + ); + + // Apply filter pushdown optimization + let config = ConfigOptions::default(); + let optimized_plan = FilterPushdown::new() + .optimize(Arc::clone(&filter2) as Arc, &config) + .unwrap(); + + // Execute the optimized plan - this should not error + let ctx = SessionContext::new(); + let result = collect(optimized_plan, ctx.task_ctx()).await.unwrap(); + + // Verify results: should return rows where time < 350 AND event = 'Ingestion' + // That's rows with time=100,200 (both have event='Ingestion'), so sizes 10,20 + let expected = [ + "+------+", "| size |", "+------+", "| 10 |", "| 20 |", "+------+", + ]; + assert_batches_eq!(expected, &result); +} diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 37bcefd418bdb..567af64c6a366 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -222,6 +222,7 @@ async fn 
test_join_with_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -284,6 +285,7 @@ async fn test_left_join_no_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -333,6 +335,7 @@ async fn test_join_with_swap_semi() { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -388,6 +391,7 @@ async fn test_join_with_swap_mark() { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -461,6 +465,7 @@ async fn test_nested_join_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(); let child_schema = child_join.schema(); @@ -478,6 +483,7 @@ async fn test_nested_join_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(); @@ -518,6 +524,7 @@ async fn test_join_no_swap() { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -745,6 +752,7 @@ async fn test_hash_join_swap_on_joins_with_projections( Some(projection), PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?); let swapped = join @@ -754,7 +762,7 @@ async fn test_hash_join_swap_on_joins_with_projections( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", ); - assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped_join.projection.as_deref().unwrap(), &[0_usize]); assert_eq!(swapped.schema().fields.len(), 1); assert_eq!(swapped.schema().fields[0].name(), "small_col"); Ok(()) @@ -906,6 +914,7 @@ fn check_join_partition_mode( None, PartitionMode::Auto, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ); @@ -1167,10 +1176,6 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: 
Option) -> Result { Ok(if partition.is_some() { Statistics::new_unknown(&self.schema) @@ -1554,6 +1559,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { None, t.initial_mode, NullEquality::NullEqualsNothing, + false, )?) as _; let optimized_join_plan = diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index e8d06d69df414..b8c4d6d6f0d7a 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -18,22 +18,24 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - coalesce_partitions_exec, global_limit_exec, local_limit_exec, sort_exec, - sort_preserving_merge_exec, stream_exec, + coalesce_partitions_exec, global_limit_exec, hash_join_exec, local_limit_exec, + sort_exec, sort_preserving_merge_exec, stream_exec, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_expr::Operator; +use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{BinaryExpr, col, lit}; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::{ExecutionPlan, get_plan_string}; @@ -87,6 +89,20 @@ fn empty_exec(schema: SchemaRef) -> Arc { 
Arc::new(EmptyExec::new(schema)) } +fn nested_loop_join_exec( + left: Arc, + right: Arc, + join_type: JoinType, +) -> Result> { + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, right, None, &join_type, None, + )?)) +} + +fn format_plan(plan: &Arc) -> String { + get_plan_string(plan).join("\n") +} + #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { @@ -94,20 +110,23 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @"StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ); Ok(()) } @@ -119,21 +138,188 @@ fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_li let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - 
assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7 + " + ); + + Ok(()) +} + +fn join_on_columns( + left_col: &str, + right_col: &str, +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + vec![( + Arc::new(datafusion_physical_expr::expressions::Column::new( + left_col, 0, + )) as _, + Arc::new(datafusion_physical_expr::expressions::Column::new( + right_col, 0, + )) as _, + )] +} + +#[test] +fn absorbs_limit_into_hash_join_inner() -> Result<()> { + // HashJoinExec with Inner join should absorb limit via with_fetch + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join (not pushed to children) + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=5 + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn 
absorbs_limit_into_hash_join_right() -> Result<()> { + // HashJoinExec with Right join should absorb limit via with_fetch + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Right)?; + let global_limit = global_limit_exec(hash_join, 0, Some(10)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=10 + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)], fetch=10 + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn absorbs_limit_into_hash_join_left() -> Result<()> { + // during probing, then unmatched rows at the end, stopping when limit is reached + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Left)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // Left join now absorbs the limit + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)], fetch=5 
+ EmptyExec + EmptyExec + " + ); - let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + Ok(()) +} + +#[test] +fn absorbs_limit_with_skip_into_hash_join() -> Result<()> { + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 3, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // With skip, GlobalLimit is kept but fetch (skip + limit = 8) is absorbed by the join + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=8 + EmptyExec + EmptyExec + " + ); Ok(()) } @@ -146,24 +332,29 @@ fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + ProjectionExec: expr=[c1@0 
as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -183,29 +374,33 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " SortPreservingMergeExec: [c1@0 ASC]", - " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + SortPreservingMergeExec: [c1@0 ASC] + SortExec: expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = 
LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "SortPreservingMergeExec: [c1@0 ASC], fetch=5", - " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + SortPreservingMergeExec: [c1@0 ASC], fetch=5 + SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -220,26 +415,31 @@ fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> R let coalesce_partitions = coalesce_partitions_exec(filter); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + CoalescePartitionsExec + FilterExec: c3@2 > 0 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, 
&ConfigOptions::new())?; - let expected = [ - "CoalescePartitionsExec: fetch=5", - " FilterExec: c3@2 > 0, fetch=5", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + CoalescePartitionsExec: fetch=5 + FilterExec: c3@2 > 0, fetch=5 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -251,20 +451,27 @@ fn merges_local_limit_with_local_limit() -> Result<()> { let child_local_limit = local_limit_exec(empty_exec, 10); let parent_local_limit = local_limit_exec(child_local_limit, 20); - let initial = get_plan_string(&parent_local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " LocalLimitExec: fetch=10", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + LocalLimitExec: fetch=10 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=0, fetch=10 + EmptyExec + " + ); Ok(()) } @@ -276,20 +483,27 @@ fn merges_global_limit_with_global_limit() -> Result<()> { let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); - let initial = get_plan_string(&parent_global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=10, fetch=20", - " 
GlobalLimitExec: skip=10, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=10, fetch=20 + GlobalLimitExec: skip=10, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -301,20 +515,27 @@ fn merges_global_limit_with_local_limit() -> Result<()> { let local_limit = local_limit_exec(empty_exec, 40); let global_limit = global_limit_exec(local_limit, 20, Some(30)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=20, fetch=30", - " LocalLimitExec: fetch=40", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=20, fetch=30 + LocalLimitExec: fetch=40 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -326,20 +547,138 @@ fn merges_local_limit_with_global_limit() -> Result<()> { let global_limit = global_limit_exec(empty_exec, 20, Some(30)); let local_limit = local_limit_exec(global_limit, 20); - let initial = get_plan_string(&local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " GlobalLimitExec: skip=20, fetch=30", - " EmptyExec", - 
]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + GlobalLimitExec: skip=20, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_nested_global_limit() -> Result<()> { + // If there are multiple limits in an execution plan, they all need to be + // preserved in the optimized plan. + // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=1 + // NestedLoopJoinExec (Left) + // EmptyExec (left side) + // GlobalLimitExec: skip=2, fetch=1 + // NestedLoopJoinExec (Right) + // EmptyExec (left side) + // EmptyExec (right side) + let schema = create_schema(); + + // Build inner join: NestedLoopJoin(Empty, Empty) + let inner_left = empty_exec(Arc::clone(&schema)); + let inner_right = empty_exec(Arc::clone(&schema)); + let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?; + + // Add inner limit: GlobalLimitExec: skip=2, fetch=1 + let inner_limit = global_limit_exec(inner_join, 2, Some(1)); + + // Build outer join: NestedLoopJoin(Empty, GlobalLimit) + let outer_left = empty_exec(Arc::clone(&schema)); + let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?; + + // Add outer limit: GlobalLimitExec: skip=1, fetch=1 + let outer_limit = global_limit_exec(outer_join, 1, Some(1)); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + 
EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_skip_before_sort() -> Result<()> { + // If there's a limit with skip before a node that (1) supports fetch but + // (2) does not support limit pushdown, that limit should not be removed. + // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=None + // SortExec: TopK(fetch=4) + // EmptyExec + let schema = create_schema(); + + let empty = empty_exec(Arc::clone(&schema)); + + let ordering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }]; + let sort = sort_exec(ordering.into(), empty) + .with_fetch(Some(4)) + .unwrap(); + + let outer_limit = global_limit_exec(sort, 1, None); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=None + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=3 + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index d11322cd26be9..cf179cb727cf1 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -24,7 +24,6 @@ mod combine_partial_final_agg; mod enforce_distribution; mod enforce_sorting; mod 
enforce_sorting_monotonicity; -#[expect(clippy::needless_pass_by_value)] mod filter_pushdown; mod join_selection; #[expect(clippy::needless_pass_by_value)] @@ -38,3 +37,5 @@ mod sanity_checker; #[expect(clippy::needless_pass_by_value)] mod test_utils; mod window_optimize; + +mod pushdown_utils; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index b33305c23ede2..fa021ed3dcce3 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -826,7 +826,7 @@ mod test { let plan_string = get_plan_string(&aggregate_exec_partial).swap_remove(0); assert_snapshot!( plan_string, - @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)], ordering_mode=Sorted" + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" ); let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; @@ -1294,4 +1294,64 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_statistics_by_partition_of_empty_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Try to test with single partition + let empty_single = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let stats = empty_single.partition_statistics(Some(0))?; + assert_eq!(stats.num_rows, Precision::Exact(0)); + assert_eq!(stats.total_byte_size, Precision::Exact(0)); + assert_eq!(stats.column_statistics.len(), 2); + + for col_stat in &stats.column_statistics { + assert_eq!(col_stat.null_count, Precision::Exact(0)); + assert_eq!(col_stat.distinct_count, Precision::Exact(0)); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + assert_eq!(col_stat.min_value, Precision::::Absent); + assert_eq!(col_stat.max_value, Precision::::Absent); + assert_eq!(col_stat.sum_value, 
Precision::::Absent); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + } + + let overall_stats = empty_single.partition_statistics(None)?; + assert_eq!(stats, overall_stats); + + validate_statistics_with_data(empty_single, vec![ExpectedStatistics::Empty], 0) + .await?; + + // Test with multiple partitions + let empty_multi: Arc = + Arc::new(EmptyExec::new(Arc::clone(&schema)).with_partitions(3)); + + let statistics = (0..empty_multi.output_partitioning().partition_count()) + .map(|idx| empty_multi.partition_statistics(Some(idx))) + .collect::>>()?; + + assert_eq!(statistics.len(), 3); + + for stat in &statistics { + assert_eq!(stat.num_rows, Precision::Exact(0)); + assert_eq!(stat.total_byte_size, Precision::Exact(0)); + } + + validate_statistics_with_data( + empty_multi, + vec![ + ExpectedStatistics::Empty, + ExpectedStatistics::Empty, + ExpectedStatistics::Empty, + ], + 0, + ) + .await?; + + Ok(()) + } } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index ff87ad7212967..00e016ae02cad 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -1284,6 +1284,7 @@ fn test_hash_join_after_projection() -> Result<()> { None, PartitionMode::Auto, NullEquality::NullEqualsNothing, + false, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ @@ -1722,3 +1723,47 @@ fn test_cooperative_exec_after_projection() -> Result<()> { Ok(()) } + +#[test] +fn test_hash_join_empty_projection_embeds() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join = Arc::new(HashJoinExec::try_new( + left_csv, + right_csv, + vec![(Arc::new(Column::new("a", 0)), Arc::new(Column::new("a", 0)))], + None, + &JoinType::Right, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?); + + // Empty 
projection: no columns needed from the join output + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![] as Vec, + join, + )?); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // The empty projection should be embedded into the HashJoinExec, + // resulting in projection=[] on the join and no ProjectionExec wrapper. + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, a@0)], projection=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs similarity index 92% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs index 1afdc4823f0a4..524d33ae6edb6 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs @@ -24,6 +24,7 @@ use datafusion_datasource::{ file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, file_stream::FileOpener, source::DataSourceExec, }; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; @@ -50,7 +51,7 @@ use std::{ pub struct TestOpener { batches: Vec, batch_size: Option, - projection: Option>, + projection: Option, predicate: Option>, } @@ -60,6 +61,7 @@ impl FileOpener for TestOpener { if 
self.batches.is_empty() { return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed()); } + let schema = self.batches[0].schema(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; let mut new_batches = Vec::new(); @@ -83,9 +85,10 @@ impl FileOpener for TestOpener { batches = new_batches; if let Some(projection) = &self.projection { + let projector = projection.make_projector(&schema)?; batches = batches .into_iter() - .map(|batch| batch.project(projection).unwrap()) + .map(|batch| projector.project_batch(&batch).unwrap()) .collect(); } @@ -103,14 +106,13 @@ pub struct TestSource { batch_size: Option, batches: Vec, metrics: ExecutionPlanMetricsSet, - projection: Option>, + projection: Option, table_schema: datafusion_datasource::TableSchema, } impl TestSource { pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { - let table_schema = - datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]); + let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), @@ -210,6 +212,30 @@ impl FileSource for TestSource { } } + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + if let Some(existing_projection) = &self.projection { + // Combine existing projection with new projection + let combined_projection = existing_projection.try_merge(projection)?; + Ok(Some(Arc::new(TestSource { + projection: Some(combined_projection), + table_schema: self.table_schema.clone(), + ..self.clone() + }))) + } else { + Ok(Some(Arc::new(TestSource { + projection: Some(projection.clone()), + ..self.clone() + }))) + } + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + fn table_schema(&self) -> &datafusion_datasource::TableSchema { &self.table_schema } @@ -332,6 +358,7 @@ pub struct OptimizationTest { } impl OptimizationTest { + 
#[expect(clippy::needless_pass_by_value)] pub fn new( input_plan: Arc, opt: O, diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 84534b4fd833d..b717f546dc422 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -1139,6 +1139,7 @@ fn hash_join_exec( None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 40beb12d48cdb..feac8190ffde4 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -247,6 +247,7 @@ pub fn hash_join_exec( None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?)) } diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 0000000000000..464d6c937b328 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; + +fn build_table(values: &[i32]) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index fa248c448683b..5f62f7204eff1 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -879,12 +879,13 @@ async fn parquet_explain_analyze() { let i_rowgroup_stat = formatted.find("row_groups_pruned_statistics").unwrap(); let i_rowgroup_bloomfilter = formatted.find("row_groups_pruned_bloom_filter").unwrap(); - let i_page = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_rows = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_pages = formatted.find("page_index_pages_pruned").unwrap(); assert!( (i_file < i_rowgroup_stat) && (i_rowgroup_stat < i_rowgroup_bloomfilter) - && (i_rowgroup_bloomfilter < i_page), + && (i_rowgroup_bloomfilter < i_page_pages && i_page_pages < i_page_rows), "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." ); } diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs index 8b56bf67a261c..ab1015b2d18d9 100644 --- a/datafusion/core/tests/sql/unparser.rs +++ b/datafusion/core/tests/sql/unparser.rs @@ -47,6 +47,7 @@ use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_sql::unparser::Unparser; use datafusion_sql::unparser::dialect::DefaultDialect; use itertools::Itertools; +use recursive::{set_minimum_stack_size, set_stack_allocation_size}; /// Paths to benchmark query files (supports running from repo root or different working directories). 
const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"]; @@ -458,5 +459,8 @@ async fn test_clickbench_unparser_roundtrip() { #[tokio::test] async fn test_tpch_unparser_roundtrip() { + // Grow stacker segments earlier to avoid deep unparser recursion overflow in q20. + set_minimum_stack_size(512 * 1024); + set_stack_allocation_size(8 * 1024 * 1024); run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await; } diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs index bda9b37ebea68..54af53ad858d4 100644 --- a/datafusion/core/tests/user_defined/relation_planner.rs +++ b/datafusion/core/tests/user_defined/relation_planner.rs @@ -68,9 +68,11 @@ fn plan_static_values_table( .project(vec![col("column1").alias(column_name)])? .build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } @@ -176,9 +178,11 @@ impl RelationPlanner for SamplingJoinPlanner { .cross_join(right_sampled)? .build()?; - Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) } - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } @@ -195,7 +199,7 @@ impl RelationPlanner for PassThroughPlanner { _context: &mut dyn RelationPlannerContext, ) -> Result { // Never handles anything - always delegates - Ok(RelationPlanning::Original(relation)) + Ok(RelationPlanning::Original(Box::new(relation))) } } @@ -217,7 +221,7 @@ impl RelationPlanner for PremiumFeaturePlanner { to unlock advanced array operations." 
.to_string(), )), - other => Ok(RelationPlanning::Original(other)), + other => Ok(RelationPlanning::Original(Box::new(other))), } } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d53e076739608..990b05c49d82b 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -84,7 +84,7 @@ use datafusion::{ physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, }, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, @@ -742,12 +742,6 @@ impl ExecutionPlan for TopKExec { state: BTreeMap::new(), })) } - - fn statistics(&self) -> Result { - // to improve the optimizability of this plan - // better statistics inference could be provided - Ok(Statistics::new_unknown(&self.schema())) - } } // A very specialized TopK implementation diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 24cade1e80d5a..b4ce3a03dbcbd 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -1306,19 +1306,14 @@ async fn create_scalar_function_from_sql_statement_default_arguments() -> Result "Error during planning: Non-default arguments cannot follow default arguments."; assert!(expected.starts_with(&err.strip_backtrace())); - // FIXME: The `DEFAULT` syntax does not work with positional params - let bad_expression_sql = r#" + let expression_sql = r#" CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) RETURNS DOUBLE RETURN $1 + $2 "#; - let err = ctx - 
.sql(bad_expression_sql) - .await - .expect_err("sqlparser error"); - let expected = - "SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")"; - assert!(expected.starts_with(&err.strip_backtrace())); + let result = ctx.sql(expression_sql).await; + + assert!(result.is_ok()); Ok(()) } diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 8be8609c62480..95694d00a6c30 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -221,6 +221,31 @@ impl TableFunctionImpl for SimpleCsvTableFunc { } } +/// Test that expressions passed to UDTFs are properly type-coerced +/// This is a regression test for https://github.com/apache/datafusion/issues/19914 +#[tokio::test] +async fn test_udtf_type_coercion() -> Result<()> { + use datafusion::datasource::MemTable; + + #[derive(Debug)] + struct NoOpTableFunc; + + impl TableFunctionImpl for NoOpTableFunc { + fn call(&self, _: &[Expr]) -> Result> { + let schema = Arc::new(arrow::datatypes::Schema::empty()); + Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)) + } + } + + let ctx = SessionContext::new(); + ctx.register_udtf("f", Arc::new(NoOpTableFunc)); + + // This should not panic - the array elements should be coerced to Float64 + let _ = ctx.sql("SELECT * FROM f(ARRAY[0.1, 1, 2])").await?; + + Ok(()) +} + fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { let mut file = File::open(csv_path)?; let (schema, _) = Format::default() diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 57baf271c5913..775325a337184 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -536,7 +536,7 @@ impl 
OddCounter { impl SimpleWindowUDF { fn new(test_state: Arc) -> Self { let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Signature::exact(vec![DataType::Int64], Volatility::Immutable); Self { signature, test_state: test_state.into(), diff --git a/datafusion/datasource-arrow/src/mod.rs b/datafusion/datasource-arrow/src/mod.rs index cbfd7887093e7..4816a45942e5a 100644 --- a/datafusion/datasource-arrow/src/mod.rs +++ b/datafusion/datasource-arrow/src/mod.rs @@ -19,7 +19,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] //! [`ArrowFormat`]: Apache Arrow file format abstractions diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 0e8f2a4d56088..053be3c9aff94 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -117,8 +117,8 @@ fn schema_to_field_with_props( .iter() .map(|s| schema_to_field_with_props(s, None, has_nullable, None)) .collect::>>()?; - let type_ids = 0_i8..fields.len() as i8; - DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) + // Assign type_ids based on the order in which they appear + DataType::Union(UnionFields::from_fields(fields), UnionMode::Dense) } } AvroSchema::Record(RecordSchema { fields, .. }) => { diff --git a/datafusion/datasource-avro/src/mod.rs b/datafusion/datasource-avro/src/mod.rs index 22c40e203a014..5ad209591e380 100644 --- a/datafusion/datasource-avro/src/mod.rs +++ b/datafusion/datasource-avro/src/mod.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! 
An [Avro](https://avro.apache.org/) based [`FileSource`](datafusion_datasource::file::FileSource) implementation and related functionality. diff --git a/datafusion/datasource-csv/src/mod.rs b/datafusion/datasource-csv/src/mod.rs index d58ce1188550c..fdfee05d86a79 100644 --- a/datafusion/datasource-csv/src/mod.rs +++ b/datafusion/datasource-csv/src/mod.rs @@ -19,7 +19,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] pub mod file_format; pub mod source; diff --git a/datafusion/datasource-json/Cargo.toml b/datafusion/datasource-json/Cargo.toml index 37fa8d43a0816..bd0cead8d2af8 100644 --- a/datafusion/datasource-json/Cargo.toml +++ b/datafusion/datasource-json/Cargo.toml @@ -44,7 +44,9 @@ datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } futures = { workspace = true } object_store = { workspace = true } +serde_json = { workspace = true } tokio = { workspace = true } +tokio-stream = { workspace = true, features = ["sync"] } # Note: add additional linter rules in lib.rs. # Rust does not support workspace + new linter rules in subcrates yet diff --git a/datafusion/datasource-json/src/file_format.rs b/datafusion/datasource-json/src/file_format.rs index a14458b5acd36..881e5f3d873e6 100644 --- a/datafusion/datasource-json/src/file_format.rs +++ b/datafusion/datasource-json/src/file_format.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions +//! 
[`JsonFormat`]: Line delimited and array JSON [`FileFormat`] abstractions use std::any::Any; use std::collections::HashMap; use std::fmt; use std::fmt::Debug; -use std::io::BufReader; +use std::io::{BufReader, Read}; use std::sync::Arc; use crate::source::JsonSource; @@ -31,6 +31,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::json; use arrow::json::reader::{ValueIter, infer_json_schema_from_iterator}; +use bytes::{Buf, Bytes}; use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{ @@ -48,6 +49,7 @@ use datafusion_datasource::file_format::{ use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::write::BatchSerializer; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; @@ -57,9 +59,8 @@ use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; +use crate::utils::JsonArrayToNdjsonReader; use async_trait::async_trait; -use bytes::{Buf, Bytes}; -use datafusion_datasource::source::DataSourceExec; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; #[derive(Default)] @@ -132,7 +133,26 @@ impl Debug for JsonFormatFactory { } } -/// New line delimited JSON `FileFormat` implementation. +/// JSON `FileFormat` implementation supporting both line-delimited and array formats. 
+/// +/// # Supported Formats +/// +/// ## Line-Delimited JSON (default, `newline_delimited = true`) +/// ```text +/// {"key1": 1, "key2": "val"} +/// {"key1": 2, "key2": "vals"} +/// ``` +/// +/// ## JSON Array Format (`newline_delimited = false`) +/// ```text +/// [ +/// {"key1": 1, "key2": "val"}, +/// {"key1": 2, "key2": "vals"} +/// ] +/// ``` +/// +/// Note: JSON array format is processed using streaming conversion, +/// which is memory-efficient even for large files. #[derive(Debug, Default)] pub struct JsonFormat { options: JsonOptions, @@ -166,6 +186,57 @@ impl JsonFormat { self.options.compression = file_compression_type.into(); self } + + /// Set whether to read as newline-delimited JSON (NDJSON). + /// + /// When `true` (default), expects newline-delimited format: + /// ```text + /// {"a": 1} + /// {"a": 2} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [{"a": 1}, {"a": 2}] + /// ``` + pub fn with_newline_delimited(mut self, newline_delimited: bool) -> Self { + self.options.newline_delimited = newline_delimited; + self + } + + /// Returns whether this format expects newline-delimited JSON. + pub fn is_newline_delimited(&self) -> bool { + self.options.newline_delimited + } +} + +/// Infer schema from JSON array format using streaming conversion. +/// +/// This function converts JSON array format to NDJSON on-the-fly and uses +/// arrow-json's schema inference. It properly tracks the number of records +/// processed for correct `records_to_read` management. +/// +/// # Returns +/// A tuple of (Schema, records_consumed) where records_consumed is the +/// number of records that were processed for schema inference. 
+fn infer_schema_from_json_array( + reader: R, + max_records: usize, +) -> Result<(Schema, usize)> { + let ndjson_reader = JsonArrayToNdjsonReader::new(reader); + + let iter = ValueIter::new(ndjson_reader, None); + let mut count = 0; + + let schema = infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < max_records; + if should_take { + count += 1; + } + should_take + }))?; + + Ok((schema, count)) } #[async_trait] @@ -202,37 +273,67 @@ impl FileFormat for JsonFormat { .schema_infer_max_rec .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD); let file_compression_type = FileCompressionType::from(self.options.compression); + let newline_delimited = self.options.newline_delimited; + for object in objects { - let mut take_while = || { - let should_take = records_to_read > 0; - if should_take { - records_to_read -= 1; - } - should_take - }; + // Early exit if we've read enough records + if records_to_read == 0 { + break; + } let r = store.as_ref().get(&object.location).await?; - let schema = match r.payload { + + let (schema, records_consumed) = match r.payload { #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(file, _) => { let decoder = file_compression_type.convert_read(file)?; - let mut reader = BufReader::new(decoder); - let iter = ValueIter::new(&mut reader, None); - infer_json_schema_from_iterator(iter.take_while(|_| take_while()))? + let reader = BufReader::new(decoder); + + if newline_delimited { + // NDJSON: use ValueIter directly + let iter = ValueIter::new(reader, None); + let mut count = 0; + let schema = + infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < records_to_read; + if should_take { + count += 1; + } + should_take + }))?; + (schema, count) + } else { + // JSON array format: use streaming converter + infer_schema_from_json_array(reader, records_to_read)? 
+ } } GetResultPayload::Stream(_) => { let data = r.bytes().await?; let decoder = file_compression_type.convert_read(data.reader())?; - let mut reader = BufReader::new(decoder); - let iter = ValueIter::new(&mut reader, None); - infer_json_schema_from_iterator(iter.take_while(|_| take_while()))? + let reader = BufReader::new(decoder); + + if newline_delimited { + let iter = ValueIter::new(reader, None); + let mut count = 0; + let schema = + infer_json_schema_from_iterator(iter.take_while(|_| { + let should_take = count < records_to_read; + if should_take { + count += 1; + } + should_take + }))?; + (schema, count) + } else { + // JSON array format: use streaming converter + infer_schema_from_json_array(reader, records_to_read)? + } } }; schemas.push(schema); - if records_to_read == 0 { - break; - } + // Correctly decrement records_to_read + records_to_read = records_to_read.saturating_sub(records_consumed); } let schema = Schema::try_merge(schemas)?; @@ -281,7 +382,10 @@ impl FileFormat for JsonFormat { } fn file_source(&self, table_schema: TableSchema) -> Arc { - Arc::new(JsonSource::new(table_schema)) + Arc::new( + JsonSource::new(table_schema) + .with_newline_delimited(self.options.newline_delimited), + ) } } diff --git a/datafusion/datasource-json/src/mod.rs b/datafusion/datasource-json/src/mod.rs index 3d27d4cc5ef5a..7dc0a0c7ba0f9 100644 --- a/datafusion/datasource-json/src/mod.rs +++ b/datafusion/datasource-json/src/mod.rs @@ -19,9 +19,9 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] pub mod file_format; pub mod source; +pub mod utils; pub use file_format::*; diff --git a/datafusion/datasource-json/src/source.rs b/datafusion/datasource-json/src/source.rs index 5797054f11b9c..867cfe0e98fea 100644 --- a/datafusion/datasource-json/src/source.rs +++ b/datafusion/datasource-json/src/source.rs @@ -15,17 +15,19 @@ 
// specific language governing permissions and limitations // under the License. -//! Execution plan for reading line-delimited JSON files +//! Execution plan for reading JSON files (line-delimited and array formats) use std::any::Any; use std::io::{BufReader, Read, Seek, SeekFrom}; +use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; +use std::task::{Context, Poll}; use crate::file_format::JsonDecoder; +use crate::utils::{ChannelReader, JsonArrayToNdjsonReader}; use datafusion_common::error::{DataFusionError, Result}; -use datafusion_common_runtime::JoinSet; +use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_datasource::decoder::{DecoderDeserializer, deserialize_stream}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; @@ -36,6 +38,7 @@ use datafusion_datasource::{ use datafusion_physical_plan::projection::ProjectionExprs; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow::array::RecordBatch; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; use datafusion_datasource::file::FileSource; @@ -43,10 +46,55 @@ use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_execution::TaskContext; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use futures::{StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; +use tokio_stream::wrappers::ReceiverStream; + +/// Channel buffer size for streaming JSON array processing. +/// With ~128KB average chunk size, 128 chunks ≈ 16MB buffer. 
+const CHANNEL_BUFFER_SIZE: usize = 128; + +/// Buffer size for JsonArrayToNdjsonReader (2MB each, 4MB total for input+output) +const JSON_CONVERTER_BUFFER_SIZE: usize = 2 * 1024 * 1024; + +// ============================================================================ +// JsonArrayStream - Custom stream wrapper to hold SpawnedTask handles +// ============================================================================ + +/// A stream wrapper that holds SpawnedTask handles to keep them alive +/// until the stream is fully consumed or dropped. +/// +/// This ensures cancel-safety: when the stream is dropped, the tasks +/// are properly aborted via SpawnedTask's Drop implementation. +struct JsonArrayStream { + inner: ReceiverStream>, + /// Task that reads from object store and sends bytes to channel. + /// Kept alive until stream is consumed or dropped. + _read_task: SpawnedTask<()>, + /// Task that parses JSON and sends RecordBatches. + /// Kept alive until stream is consumed or dropped. + _parse_task: SpawnedTask<()>, +} + +impl Stream for JsonArrayStream { + type Item = std::result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} +// ============================================================================ +// JsonOpener and JsonSource +// ============================================================================ /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] pub struct JsonOpener { @@ -54,21 +102,26 @@ pub struct JsonOpener { projected_schema: SchemaRef, file_compression_type: FileCompressionType, object_store: Arc, + /// When `true` (default), expects newline-delimited JSON (NDJSON). + /// When `false`, expects JSON array format `[{...}, {...}]`. 
+ newline_delimited: bool, } impl JsonOpener { - /// Returns a [`JsonOpener`] + /// Returns a [`JsonOpener`] pub fn new( batch_size: usize, projected_schema: SchemaRef, file_compression_type: FileCompressionType, object_store: Arc, + newline_delimited: bool, ) -> Self { Self { batch_size, projected_schema, file_compression_type, object_store, + newline_delimited, } } } @@ -80,6 +133,9 @@ pub struct JsonSource { batch_size: Option, metrics: ExecutionPlanMetricsSet, projection: SplitProjection, + /// When `true` (default), expects newline-delimited JSON (NDJSON). + /// When `false`, expects JSON array format `[{...}, {...}]`. + newline_delimited: bool, } impl JsonSource { @@ -91,8 +147,18 @@ impl JsonSource { table_schema, batch_size: None, metrics: ExecutionPlanMetricsSet::new(), + newline_delimited: true, } } + + /// Set whether to read as newline-delimited JSON. + /// + /// When `true` (default), expects newline-delimited format. + /// When `false`, expects JSON array format `[{...}, {...}]`. + pub fn with_newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } impl From for Arc { @@ -120,6 +186,7 @@ impl FileSource for JsonSource { projected_schema, file_compression_type: base_config.file_compression_type, object_store, + newline_delimited: self.newline_delimited, }) as Arc; // Wrap with ProjectionOpener @@ -172,7 +239,7 @@ impl FileSource for JsonSource { } impl FileOpener for JsonOpener { - /// Open a partitioned NDJSON file. + /// Open a partitioned JSON file. /// /// If `file_meta.range` is `None`, the entire file is opened. /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. @@ -181,11 +248,23 @@ impl FileOpener for JsonOpener { /// are applied to determine which lines to read: /// 1. The first line of the partition is the line in which the index of the first character >= `start`. /// 2. 
The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// Note: JSON array format does not support range-based scanning. fn open(&self, partitioned_file: PartitionedFile) -> Result { let store = Arc::clone(&self.object_store); let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; let file_compression_type = self.file_compression_type.to_owned(); + let newline_delimited = self.newline_delimited; + + // JSON array format requires reading the complete file + if !newline_delimited && partitioned_file.range.is_some() { + return Err(DataFusionError::NotImplemented( + "JSON array format does not support range-based file scanning. \ + Disable repartition_file_scans or use newline-delimited JSON format." + .to_string(), + )); + } Ok(Box::pin(async move { let calculated_range = @@ -218,31 +297,150 @@ impl FileOpener for JsonOpener { Some(_) => { file.seek(SeekFrom::Start(result.range.start as _))?; let limit = result.range.end - result.range.start; - file_compression_type.convert_read(file.take(limit as u64))? + file_compression_type.convert_read(file.take(limit))? 
} }; - let reader = ReaderBuilder::new(schema) - .with_batch_size(batch_size) - .build(BufReader::new(bytes))?; - - Ok(futures::stream::iter(reader) - .map(|r| r.map_err(Into::into)) - .boxed()) + if newline_delimited { + // NDJSON: use BufReader directly + let reader = BufReader::new(bytes); + let arrow_reader = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(reader)?; + + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) + } else { + // JSON array format: wrap with streaming converter + let ndjson_reader = JsonArrayToNdjsonReader::with_capacity( + bytes, + JSON_CONVERTER_BUFFER_SIZE, + ); + let arrow_reader = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(ndjson_reader)?; + + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) + } } GetResultPayload::Stream(s) => { - let s = s.map_err(DataFusionError::from); - - let decoder = ReaderBuilder::new(schema) - .with_batch_size(batch_size) - .build_decoder()?; - let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - - let stream = deserialize_stream( - input, - DecoderDeserializer::new(JsonDecoder::new(decoder)), - ); - Ok(stream.map_err(Into::into).boxed()) + if newline_delimited { + // Newline-delimited JSON (NDJSON) streaming reader + let s = s.map_err(DataFusionError::from); + let decoder = ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build_decoder()?; + let input = + file_compression_type.convert_stream(s.boxed())?.fuse(); + let stream = deserialize_stream( + input, + DecoderDeserializer::new(JsonDecoder::new(decoder)), + ); + Ok(stream.map_err(Into::into).boxed()) + } else { + // JSON array format: streaming conversion with channel-based byte transfer + // + // Architecture: + // 1. Async task reads from object store stream, decompresses, sends to channel + // 2. Blocking task receives bytes, converts JSON array to NDJSON, parses to Arrow + // 3. 
RecordBatches are sent back via another channel + // + // Memory budget (~32MB): + // - sync_channel: CHANNEL_BUFFER_SIZE chunks (~16MB) + // - JsonArrayToNdjsonReader: 2 × JSON_CONVERTER_BUFFER_SIZE (~4MB) + // - Arrow JsonReader internal buffer (~8MB) + // - Miscellaneous (~4MB) + + let s = s.map_err(DataFusionError::from); + let decompressed_stream = + file_compression_type.convert_stream(s.boxed())?; + + // Channel for bytes: async producer -> blocking consumer + // Uses tokio::sync::mpsc so the async send never blocks a + // tokio worker thread; the consumer calls blocking_recv() + // inside spawn_blocking. + let (byte_tx, byte_rx) = tokio::sync::mpsc::channel::( + CHANNEL_BUFFER_SIZE, + ); + + // Channel for results: sync producer -> async consumer + let (result_tx, result_rx) = tokio::sync::mpsc::channel(2); + let error_tx = result_tx.clone(); + + // Async task: read from object store stream and send bytes to channel + // Store the SpawnedTask to keep it alive until stream is dropped + let read_task = SpawnedTask::spawn(async move { + tokio::pin!(decompressed_stream); + while let Some(chunk) = decompressed_stream.next().await { + match chunk { + Ok(bytes) => { + if byte_tx.send(bytes).await.is_err() { + break; // Consumer dropped + } + } + Err(e) => { + let _ = error_tx + .send(Err( + arrow::error::ArrowError::ExternalError( + Box::new(e), + ), + )) + .await; + break; + } + } + } + // byte_tx dropped here, signals EOF to ChannelReader + }); + + // Blocking task: receive bytes from channel and parse JSON + // Store the SpawnedTask to keep it alive until stream is dropped + let parse_task = SpawnedTask::spawn_blocking(move || { + let channel_reader = ChannelReader::new(byte_rx); + let mut ndjson_reader = + JsonArrayToNdjsonReader::with_capacity( + channel_reader, + JSON_CONVERTER_BUFFER_SIZE, + ); + + match ReaderBuilder::new(schema) + .with_batch_size(batch_size) + .build(&mut ndjson_reader) + { + Ok(arrow_reader) => { + for batch_result in arrow_reader { + 
if result_tx.blocking_send(batch_result).is_err() + { + break; // Receiver dropped + } + } + } + Err(e) => { + let _ = result_tx.blocking_send(Err(e)); + } + } + + // Validate the JSON array was properly formed + if let Err(e) = ndjson_reader.validate_complete() { + let _ = result_tx.blocking_send(Err( + arrow::error::ArrowError::JsonError(e.to_string()), + )); + } + // result_tx dropped here, closes the stream + }); + + // Wrap in JsonArrayStream to keep tasks alive until stream is consumed + let stream = JsonArrayStream { + inner: ReceiverStream::new(result_rx), + _read_task: read_task, + _parse_task: parse_task, + }; + + Ok(stream.map(|r| r.map_err(Into::into)).boxed()) + } } } })) @@ -303,3 +501,307 @@ pub async fn plan_to_json( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use bytes::Bytes; + use datafusion_datasource::FileRange; + use futures::TryStreamExt; + use object_store::PutPayload; + use object_store::memory::InMemory; + use object_store::path::Path; + + /// Helper to create a test schema + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("name", DataType::Utf8, true), + ])) + } + + #[tokio::test] + async fn test_json_array_from_file() -> Result<()> { + // Test reading JSON array format from a file + let json_data = r#"[{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]"#; + + let store = Arc::new(InMemory::new()); + let path = Path::from("test.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches.len(), 1); + 
assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_from_stream() -> Result<()> { + // Test reading JSON array format from object store stream (simulates S3) + let json_data = r#"[{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}, {"id": 3, "name": "charlie"}]"#; + + // Use InMemory store which returns Stream payload + let store = Arc::new(InMemory::new()); + let path = Path::from("test_stream.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 2, // small batch size to test multiple batches + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_objects() -> Result<()> { + // Test JSON array with nested objects and arrays + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("data", DataType::Utf8, true), + ])); + + let json_data = r#"[ + {"id": 1, "data": "{\"nested\": true}"}, + {"id": 2, "data": "[1, 2, 3]"} + ]"#; + + let store = Arc::new(InMemory::new()); + let path = Path::from("nested.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + schema, + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn 
test_json_array_empty() -> Result<()> { + // Test empty JSON array + let json_data = "[]"; + + let store = Arc::new(InMemory::new()); + let path = Path::from("empty.json"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 0); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_range_not_supported() { + // Test that range-based scanning returns error for JSON array format + let store = Arc::new(InMemory::new()); + let path = Path::from("test.json"); + store + .put(&path, PutPayload::from_static(b"[]")) + .await + .unwrap(); + + let opener = JsonOpener::new( + 1024, + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, // JSON array format + ); + + let meta = store.head(&path).await.unwrap(); + let mut file = PartitionedFile::new(path.to_string(), meta.size); + file.range = Some(FileRange { start: 0, end: 10 }); + + let result = opener.open(file); + match result { + Ok(_) => panic!("Expected error for range-based JSON array scanning"), + Err(e) => { + assert!( + e.to_string().contains("does not support range-based"), + "Unexpected error message: {e}" + ); + } + } + } + + #[tokio::test] + async fn test_ndjson_still_works() -> Result<()> { + // Ensure NDJSON format still works correctly + let json_data = + "{\"id\": 1, \"name\": \"alice\"}\n{\"id\": 2, \"name\": \"bob\"}\n"; + + let store = Arc::new(InMemory::new()); + let path = Path::from("test.ndjson"); + store + .put(&path, PutPayload::from_static(json_data.as_bytes())) + .await?; + + let opener = JsonOpener::new( + 1024, + 
test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + true, // NDJSON format + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_large_file() -> Result<()> { + // Test with a larger JSON array to verify streaming works + let mut json_data = String::from("["); + for i in 0..1000 { + if i > 0 { + json_data.push(','); + } + json_data.push_str(&format!(r#"{{"id": {i}, "name": "user{i}"}}"#)); + } + json_data.push(']'); + + let store = Arc::new(InMemory::new()); + let path = Path::from("large.json"); + store + .put(&path, PutPayload::from(Bytes::from(json_data))) + .await?; + + let opener = JsonOpener::new( + 100, // batch size of 100 + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let stream = opener.open(file)?.await?; + let batches: Vec<_> = stream.try_collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 1000); + + // Should have multiple batches due to batch_size=100 + assert!(batches.len() >= 10); + + Ok(()) + } + + #[tokio::test] + async fn test_json_array_stream_cancellation() -> Result<()> { + // Test that cancellation works correctly (tasks are aborted when stream is dropped) + let mut json_data = String::from("["); + for i in 0..10000 { + if i > 0 { + json_data.push(','); + } + json_data.push_str(&format!(r#"{{"id": {i}, "name": "user{i}"}}"#)); + } + json_data.push(']'); + + let store = Arc::new(InMemory::new()); + let path = Path::from("cancel_test.json"); + store + .put(&path, PutPayload::from(Bytes::from(json_data))) + .await?; + + let opener = 
JsonOpener::new( + 10, // small batch size + test_schema(), + FileCompressionType::UNCOMPRESSED, + store.clone(), + false, + ); + + let meta = store.head(&path).await?; + let file = PartitionedFile::new(path.to_string(), meta.size); + + let mut stream = opener.open(file)?.await?; + + // Read only first batch, then drop the stream (simulating cancellation) + let first_batch = stream.next().await; + assert!(first_batch.is_some()); + + // Drop the stream - this should abort the spawned tasks via SpawnedTask's Drop + drop(stream); + + // Give tasks time to be aborted + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // If we reach here without hanging, cancellation worked + Ok(()) + } +} diff --git a/datafusion/datasource-json/src/utils.rs b/datafusion/datasource-json/src/utils.rs new file mode 100644 index 0000000000000..bc75799edff73 --- /dev/null +++ b/datafusion/datasource-json/src/utils.rs @@ -0,0 +1,778 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Utility types for JSON processing + +use std::io::{BufRead, Read}; + +use bytes::Bytes; + +// ============================================================================ +// JsonArrayToNdjsonReader - Streaming JSON Array to NDJSON Converter +// ============================================================================ +// +// Architecture: +// +// ```text +// ┌─────────────────────────────────────────────────────────────┐ +// │ JSON Array File (potentially very large, e.g. 33GB) │ +// │ [{"a":1}, {"a":2}, {"a":3}, ...... {"a":1000000}] │ +// └─────────────────────────────────────────────────────────────┘ +// │ +// ▼ read chunks via ChannelReader +// ┌───────────────────┐ +// │ JsonArrayToNdjson │ ← character substitution only: +// │ Reader │ '[' skip, ',' → '\n', ']' stop +// └───────────────────┘ +// │ +// ▼ outputs NDJSON format +// ┌───────────────────┐ +// │ Arrow Reader │ ← internal buffer, batch parsing +// │ batch_size=8192 │ +// └───────────────────┘ +// │ +// ▼ outputs RecordBatch +// ┌───────────────────┐ +// │ RecordBatch │ +// └───────────────────┘ +// ``` +// +// Memory Efficiency: +// +// | Approach | Memory for 33GB file | Parse count | +// |---------------------------------------|----------------------|-------------| +// | Load entire file + serde_json | ~100GB+ | 3x | +// | Streaming with JsonArrayToNdjsonReader| ~32MB (configurable) | 1x | +// +// Design Note: +// +// This implementation uses `inner: R` directly (not `BufReader`) and manages +// its own input buffer. This is critical for compatibility with `SyncIoBridge` +// and `ChannelReader` in `spawn_blocking` contexts. 
+// + +/// Default buffer size for JsonArrayToNdjsonReader (2MB for better throughput) +const DEFAULT_BUF_SIZE: usize = 2 * 1024 * 1024; + +/// Parser state for JSON array streaming +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum JsonArrayState { + /// Initial state, looking for opening '[' + Start, + /// Inside the JSON array, processing objects + InArray, + /// Reached the closing ']', finished + Done, +} + +/// A streaming reader that converts JSON array format to NDJSON format. +/// +/// This reader wraps an underlying reader containing JSON array data +/// `[{...}, {...}, ...]` and transforms it on-the-fly to newline-delimited +/// JSON format that Arrow's JSON reader can process. +/// +/// Implements both `Read` and `BufRead` traits for compatibility with Arrow's +/// `ReaderBuilder::build()` which requires `BufRead`. +/// +/// # Transformation Rules +/// +/// - Skip leading `[` and whitespace before it +/// - Convert top-level `,` (between objects) to `\n` +/// - Skip whitespace at top level (between objects) +/// - Stop at trailing `]` +/// - Preserve everything inside objects (including nested `[`, `]`, `,`) +/// - Properly handle strings (ignore special chars inside quotes) +/// +/// # Example +/// +/// ```text +/// Input: [{"a":1}, {"b":[1,2]}, {"c":"x,y"}] +/// Output: {"a":1} +/// {"b":[1,2]} +/// {"c":"x,y"} +/// ``` +pub struct JsonArrayToNdjsonReader { + /// Inner reader - we use R directly (not `BufReader`) for SyncIoBridge compatibility + inner: R, + state: JsonArrayState, + /// Tracks nesting depth of `{` and `[` to identify top-level commas + depth: i32, + /// Whether we're currently inside a JSON string + in_string: bool, + /// Whether the next character is escaped (after `\`) + escape_next: bool, + /// Input buffer - stores raw bytes read from inner reader + input_buffer: Vec, + /// Current read position in input buffer + input_pos: usize, + /// Number of valid bytes in input buffer + input_filled: usize, + /// Output buffer - stores 
transformed NDJSON bytes + output_buffer: Vec, + /// Current read position in output buffer + output_pos: usize, + /// Number of valid bytes in output buffer + output_filled: usize, + /// Whether trailing non-whitespace content was detected after ']' + has_trailing_content: bool, + /// Whether leading non-whitespace content was detected before '[' + has_leading_content: bool, +} + +impl JsonArrayToNdjsonReader { + /// Create a new streaming reader that converts JSON array to NDJSON. + pub fn new(reader: R) -> Self { + Self::with_capacity(reader, DEFAULT_BUF_SIZE) + } + + /// Create a new streaming reader with custom buffer size. + /// + /// Larger buffers improve throughput but use more memory. + /// Total memory usage is approximately 2 * capacity (input + output buffers). + pub fn with_capacity(reader: R, capacity: usize) -> Self { + Self { + inner: reader, + state: JsonArrayState::Start, + depth: 0, + in_string: false, + escape_next: false, + input_buffer: vec![0; capacity], + input_pos: 0, + input_filled: 0, + output_buffer: vec![0; capacity], + output_pos: 0, + output_filled: 0, + has_trailing_content: false, + has_leading_content: false, + } + } + + /// Check if the JSON array was properly terminated. + /// + /// This should be called after all data has been read. 
+ /// + /// Returns an error if: + /// - Unbalanced braces/brackets (depth != 0) + /// - Unterminated string + /// - Missing closing `]` + /// - Unexpected trailing content after `]` + pub fn validate_complete(&self) -> std::io::Result<()> { + if self.has_leading_content { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON: unexpected leading content before '['", + )); + } + if self.depth != 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON array: unbalanced braces or brackets", + )); + } + if self.in_string { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON array: unterminated string", + )); + } + if self.state != JsonArrayState::Done { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Incomplete JSON array: expected closing bracket ']'", + )); + } + if self.has_trailing_content { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Malformed JSON: unexpected trailing content after ']'", + )); + } + Ok(()) + } + + /// Process a single byte and return the transformed byte (if any) + #[inline] + fn process_byte(&mut self, byte: u8) -> Option { + match self.state { + JsonArrayState::Start => { + // Looking for the opening '[', skip whitespace + if byte == b'[' { + self.state = JsonArrayState::InArray; + } else if !byte.is_ascii_whitespace() { + self.has_leading_content = true; + } + None + } + JsonArrayState::InArray => { + // Handle escape sequences in strings + if self.escape_next { + self.escape_next = false; + return Some(byte); + } + + if self.in_string { + // Inside a string: handle escape and closing quote + match byte { + b'\\' => self.escape_next = true, + b'"' => self.in_string = false, + _ => {} + } + Some(byte) + } else { + // Outside strings: track depth and transform + match byte { + b'"' => { + self.in_string = true; + Some(byte) + } + b'{' | b'[' => { + self.depth += 1; + Some(byte) + } + b'}' => { 
+ self.depth -= 1; + Some(byte) + } + b']' => { + if self.depth == 0 { + // Top-level ']' means end of array + self.state = JsonArrayState::Done; + None + } else { + // Nested ']' inside an object + self.depth -= 1; + Some(byte) + } + } + b',' if self.depth == 0 => { + // Top-level comma between objects → newline + Some(b'\n') + } + _ => { + // At depth 0, skip whitespace between objects + if self.depth == 0 && byte.is_ascii_whitespace() { + None + } else { + Some(byte) + } + } + } + } + } + JsonArrayState::Done => { + // After ']', check for non-whitespace trailing content + if !byte.is_ascii_whitespace() { + self.has_trailing_content = true; + } + None + } + } + } + + /// Refill input buffer from inner reader if needed. + /// Returns true if there's data available, false on EOF. + fn refill_input_if_needed(&mut self) -> std::io::Result { + if self.input_pos >= self.input_filled { + // Input buffer exhausted, read more from inner + let bytes_read = self.inner.read(&mut self.input_buffer)?; + if bytes_read == 0 { + return Ok(false); // EOF + } + self.input_pos = 0; + self.input_filled = bytes_read; + } + Ok(true) + } + + /// Fill the output buffer with transformed data. + /// + /// This method manages its own input buffer, reading from the inner reader + /// as needed. When the output buffer is full, we stop processing but preserve + /// the current position in the input buffer for the next call. + fn fill_output_buffer(&mut self) -> std::io::Result<()> { + let mut write_pos = 0; + + while write_pos < self.output_buffer.len() { + // Refill input buffer if exhausted + if !self.refill_input_if_needed()? 
{ + break; // EOF + } + + // Process bytes from input buffer + while self.input_pos < self.input_filled + && write_pos < self.output_buffer.len() + { + let byte = self.input_buffer[self.input_pos]; + self.input_pos += 1; + + if let Some(transformed) = self.process_byte(byte) { + self.output_buffer[write_pos] = transformed; + write_pos += 1; + } + } + } + + self.output_pos = 0; + self.output_filled = write_pos; + Ok(()) + } +} + +impl Read for JsonArrayToNdjsonReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // If output buffer is empty, fill it + if self.output_pos >= self.output_filled { + self.fill_output_buffer()?; + if self.output_filled == 0 { + return Ok(0); // EOF + } + } + + // Copy from output buffer to caller's buffer + let available = self.output_filled - self.output_pos; + let to_copy = std::cmp::min(available, buf.len()); + buf[..to_copy].copy_from_slice( + &self.output_buffer[self.output_pos..self.output_pos + to_copy], + ); + self.output_pos += to_copy; + Ok(to_copy) + } +} + +impl BufRead for JsonArrayToNdjsonReader { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + if self.output_pos >= self.output_filled { + self.fill_output_buffer()?; + } + Ok(&self.output_buffer[self.output_pos..self.output_filled]) + } + + fn consume(&mut self, amt: usize) { + self.output_pos = std::cmp::min(self.output_pos + amt, self.output_filled); + } +} + +// ============================================================================ +// ChannelReader - Sync reader that receives bytes from async channel +// ============================================================================ +// +// Architecture: +// +// ```text +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ S3 / MinIO (async) │ +// │ (33GB JSON Array File) │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ async stream (Bytes chunks) +// 
┌─────────────────────────────────────────────────────────────────────────┐ +// │ Async Task (tokio runtime) │ +// │ while let Some(chunk) = stream.next().await │ +// │ byte_tx.send(chunk) │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ tokio::sync::mpsc::channel +// │ (bounded, ~32MB buffer) +// ▼ +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ Blocking Task (spawn_blocking) │ +// │ ┌──────────────┐ ┌────────────────────────┐ ┌──────────────────┐ │ +// │ │ChannelReader │ → │JsonArrayToNdjsonReader │ → │ Arrow JsonReader │ │ +// │ │ (Read) │ │ [{},...] → {}\n{} │ │ (RecordBatch) │ │ +// │ └──────────────┘ └────────────────────────┘ └──────────────────┘ │ +// └─────────────────────────────────────────────────────────────────────────┘ +// │ +// ▼ tokio::sync::mpsc::channel +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ ReceiverStream (async) │ +// │ → DataFusion execution engine │ +// └─────────────────────────────────────────────────────────────────────────┘ +// ``` +// +// Memory Budget (~32MB total): +// - sync_channel buffer: 128 chunks × ~128KB = ~16MB +// - JsonArrayToNdjsonReader: 2 × 2MB = 4MB +// - Arrow JsonReader internal: ~8MB +// - Miscellaneous: ~4MB +// + +/// A synchronous `Read` implementation that receives bytes from an async channel. +/// +/// This enables true streaming between async and sync contexts without +/// loading the entire file into memory. Uses `tokio::sync::mpsc::Receiver` +/// with `blocking_recv()` so the async producer never blocks a tokio worker +/// thread, while the sync consumer (running in `spawn_blocking`) safely blocks. +pub struct ChannelReader { + rx: tokio::sync::mpsc::Receiver, + current: Option, + pos: usize, +} + +impl ChannelReader { + /// Create a new ChannelReader from a tokio mpsc receiver. 
+ pub fn new(rx: tokio::sync::mpsc::Receiver) -> Self { + Self { + rx, + current: None, + pos: 0, + } + } +} + +impl Read for ChannelReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + loop { + // If we have current chunk with remaining data, read from it + if let Some(ref chunk) = self.current { + let remaining = chunk.len() - self.pos; + if remaining > 0 { + let to_copy = std::cmp::min(remaining, buf.len()); + buf[..to_copy].copy_from_slice(&chunk[self.pos..self.pos + to_copy]); + self.pos += to_copy; + return Ok(to_copy); + } + } + + // Current chunk exhausted, get next from channel + match self.rx.blocking_recv() { + Some(bytes) => { + self.current = Some(bytes); + self.pos = 0; + // Loop back to read from new chunk + } + None => return Ok(0), // Channel closed = EOF + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_json_array_to_ndjson_simple() { + let input = r#"[{"a":1}, {"a":2}, {"a":3}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}\n{\"a\":2}\n{\"a\":3}"); + } + + #[test] + fn test_json_array_to_ndjson_nested() { + let input = r#"[{"a":{"b":1}}, {"c":[1,2,3]}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":{\"b\":1}}\n{\"c\":[1,2,3]}"); + } + + #[test] + fn test_json_array_to_ndjson_strings_with_special_chars() { + let input = r#"[{"a":"[1,2]"}, {"b":"x,y"}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":\"[1,2]\"}\n{\"b\":\"x,y\"}"); + } + + #[test] + fn test_json_array_to_ndjson_escaped_quotes() { + let input = r#"[{"a":"say \"hello\""}, {"b":1}]"#; + let mut reader = 
JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":\"say \\\"hello\\\"\"}\n{\"b\":1}"); + } + + #[test] + fn test_json_array_to_ndjson_empty() { + let input = r#"[]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, ""); + } + + #[test] + fn test_json_array_to_ndjson_single_element() { + let input = r#"[{"a":1}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}"); + } + + #[test] + fn test_json_array_to_ndjson_bufread() { + let input = r#"[{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + + let buf = reader.fill_buf().unwrap(); + assert!(!buf.is_empty()); + + let first_len = buf.len(); + reader.consume(first_len); + + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + } + + #[test] + fn test_json_array_to_ndjson_whitespace() { + let input = r#" [ {"a":1} , {"a":2} ] "#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + // Top-level whitespace is skipped, internal whitespace preserved + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + } + + #[test] + fn test_validate_complete_valid_json() { + let valid_json = r#"[{"a":1},{"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(valid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + reader.validate_complete().unwrap(); + } + + #[test] + fn test_json_array_with_trailing_junk() { + let input = r#" [ {"a":1} , {"a":2} ] some { junk [ here ] "#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + 
reader.read_to_string(&mut output).unwrap(); + + // Should extract the valid array content + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // But validation should catch the trailing junk + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("trailing content") + || err_msg.contains("Unexpected trailing"), + "Expected trailing content error, got: {err_msg}" + ); + } + + #[test] + fn test_validate_complete_incomplete_array() { + let invalid_json = r#"[{"a":1},{"a":2}"#; // Missing closing ] + let mut reader = JsonArrayToNdjsonReader::new(invalid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("expected closing bracket") + || err_msg.contains("missing closing"), + "Expected missing bracket error, got: {err_msg}" + ); + } + + #[test] + fn test_validate_complete_unbalanced_braces() { + let invalid_json = r#"[{"a":1},{"a":2]"#; // Wrong closing bracket + let mut reader = JsonArrayToNdjsonReader::new(invalid_json.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("unbalanced") + || err_msg.contains("expected closing bracket"), + "Expected unbalanced or missing bracket error, got: {err_msg}" + ); + } + + #[test] + fn test_json_array_with_leading_junk() { + let input = r#"junk[{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Should still extract the valid array content + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // But validation should catch the leading 
junk + let result = reader.validate_complete(); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("leading content"), + "Expected leading content error, got: {err_msg}" + ); + } + + #[test] + fn test_json_array_with_leading_whitespace_ok() { + let input = r#" + [{"a":1}, {"a":2}]"#; + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + assert_eq!(output, "{\"a\":1}\n{\"a\":2}"); + + // Leading whitespace should be fine + reader.validate_complete().unwrap(); + } + + #[test] + fn test_validate_complete_valid_with_trailing_whitespace() { + let input = r#"[{"a":1},{"a":2}] + "#; // Trailing whitespace is OK + let mut reader = JsonArrayToNdjsonReader::new(input.as_bytes()); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Whitespace after ] should be allowed + reader.validate_complete().unwrap(); + } + + /// Test that data is not lost at buffer boundaries. + /// + /// This test creates input larger than the internal buffer to verify + /// that newline characters are not dropped when they occur at buffer boundaries. 
+ #[test] + fn test_buffer_boundary_no_data_loss() { + // Create objects ~9KB each, so 10 objects = ~90KB + let large_value = "x".repeat(9000); + + let mut objects = vec![]; + for i in 0..10 { + objects.push(format!(r#"{{"id":{i},"data":"{large_value}"}}"#)); + } + + let input = format!("[{}]", objects.join(",")); + + // Use small buffer to force multiple fill cycles + let mut reader = JsonArrayToNdjsonReader::with_capacity(input.as_bytes(), 8192); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + // Verify correct number of newlines (9 newlines separate 10 objects) + let newline_count = output.matches('\n').count(); + assert_eq!( + newline_count, 9, + "Expected 9 newlines separating 10 objects, got {newline_count}" + ); + + // Verify each line is valid JSON + for (i, line) in output.lines().enumerate() { + let parsed: Result = serde_json::from_str(line); + assert!( + parsed.is_ok(), + "Line {} is not valid JSON: {}...", + i, + &line[..100.min(line.len())] + ); + + // Verify the id field matches expected value + let value = parsed.unwrap(); + assert_eq!( + value["id"].as_i64(), + Some(i as i64), + "Object {i} has wrong id" + ); + } + } + + /// Test with real-world-like data format (with leading whitespace and newlines) + #[test] + fn test_real_world_format_large() { + let large_value = "x".repeat(8000); + + // Format similar to real files: opening bracket on its own line, + // each object indented with 2 spaces + let mut objects = vec![]; + for i in 0..10 { + objects.push(format!(r#" {{"id":{i},"data":"{large_value}"}}"#)); + } + + let input = format!("[\n{}\n]", objects.join(",\n")); + + let mut reader = JsonArrayToNdjsonReader::with_capacity(input.as_bytes(), 8192); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + let lines: Vec<&str> = output.lines().collect(); + assert_eq!(lines.len(), 10, "Expected 10 objects"); + + for (i, line) in lines.iter().enumerate() { + assert!( + 
line.starts_with("{\"id\""), + "Line {} should start with object, got: {}...", + i, + &line[..50.min(line.len())] + ); + } + } + + /// Test ChannelReader + #[test] + fn test_channel_reader() { + let (tx, rx) = tokio::sync::mpsc::channel(4); + + // Send some chunks (try_send is non-async) + tx.try_send(Bytes::from("Hello, ")).unwrap(); + tx.try_send(Bytes::from("World!")).unwrap(); + drop(tx); // Close channel + + let mut reader = ChannelReader::new(rx); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + + assert_eq!(output, "Hello, World!"); + } + + /// Test ChannelReader with small reads + #[test] + fn test_channel_reader_small_reads() { + let (tx, rx) = tokio::sync::mpsc::channel(4); + + tx.try_send(Bytes::from("ABCDEFGHIJ")).unwrap(); + drop(tx); + + let mut reader = ChannelReader::new(rx); + let mut buf = [0u8; 3]; + + // Read in small chunks + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"ABC"); + + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"DEF"); + + assert_eq!(reader.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"GHI"); + + assert_eq!(reader.read(&mut buf).unwrap(), 1); + assert_eq!(&buf[..1], b"J"); + + // EOF + assert_eq!(reader.read(&mut buf).unwrap(), 0); + } +} diff --git a/datafusion/datasource-parquet/src/access_plan.rs b/datafusion/datasource-parquet/src/access_plan.rs index 570792d40e5b4..44911fcf2a9ca 100644 --- a/datafusion/datasource-parquet/src/access_plan.rs +++ b/datafusion/datasource-parquet/src/access_plan.rs @@ -82,6 +82,10 @@ use parquet::file::metadata::RowGroupMetaData; /// └───────────────────┘ /// Row Group 3 /// ``` +/// +/// For more background, please also see the [Embedding User-Defined Indexes in Apache Parquet Files blog] +/// +/// [Embedding User-Defined Indexes in Apache Parquet Files blog]: https://datafusion.apache.org/blog/2025/07/14/user-defined-parquet-indexes #[derive(Debug, Clone, PartialEq)] pub struct ParquetAccessPlan { /// How to 
access the i-th row group diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 6635c9072dd97..d59b42ed15d15 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -1360,7 +1360,7 @@ impl FileSink for ParquetSink { parquet_props.clone(), ) .await?; - let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + let reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { @@ -1465,7 +1465,7 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, - mut reservation: MemoryReservation, + reservation: MemoryReservation, ) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; @@ -1550,7 +1550,7 @@ fn spawn_rg_join_and_finalize_task( rg_rows: usize, pool: &Arc, ) -> SpawnedTask { - let mut rg_reservation = + let rg_reservation = MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); SpawnedTask::spawn(async move { @@ -1682,12 +1682,12 @@ async fn concatenate_parallel_row_groups( mut object_store_writer: Box, pool: Arc, ) -> Result { - let mut file_reservation = + let file_reservation = MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; - let (serialized_columns, mut rg_reservation, _cnt) = + let (serialized_columns, rg_reservation, _cnt) = result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; let mut rg_out = parquet_writer.next_row_group()?; diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index b763f817a0268..5a4c0bcdd514d 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ 
b/datafusion/datasource-parquet/src/metadata.rs @@ -21,7 +21,7 @@ use crate::{ ObjectStoreFetch, apply_file_schema_type_coercions, coerce_int96_to_resolution, }; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::and; use arrow::compute::kernels::cmp::eq; use arrow::compute::sum; @@ -487,22 +487,40 @@ fn summarize_min_max_null_counts( if let Some(max_acc) = &mut accumulators.max_accs[logical_schema_index] { max_acc.update_batch(&[Arc::clone(&max_values)])?; - let mut cur_max_acc = max_acc.clone(); - accumulators.is_max_value_exact[logical_schema_index] = has_any_exact_match( - &cur_max_acc.evaluate()?, - &max_values, - &is_max_value_exact_stat, - ); + + // handle the common special case when all row groups have exact statistics + let exactness = &is_max_value_exact_stat; + if !exactness.is_empty() + && exactness.null_count() == 0 + && exactness.true_count() == exactness.len() + { + accumulators.is_max_value_exact[logical_schema_index] = Some(true); + } else if exactness.true_count() == 0 { + accumulators.is_max_value_exact[logical_schema_index] = Some(false); + } else { + let val = max_acc.evaluate()?; + accumulators.is_max_value_exact[logical_schema_index] = + has_any_exact_match(&val, &max_values, exactness); + } } if let Some(min_acc) = &mut accumulators.min_accs[logical_schema_index] { min_acc.update_batch(&[Arc::clone(&min_values)])?; - let mut cur_min_acc = min_acc.clone(); - accumulators.is_min_value_exact[logical_schema_index] = has_any_exact_match( - &cur_min_acc.evaluate()?, - &min_values, - &is_min_value_exact_stat, - ); + + // handle the common special case when all row groups have exact statistics + let exactness = &is_min_value_exact_stat; + if !exactness.is_empty() + && exactness.null_count() == 0 + && exactness.true_count() == exactness.len() + { + accumulators.is_min_value_exact[logical_schema_index] = Some(true); + } else if exactness.true_count() == 0 { + 
accumulators.is_min_value_exact[logical_schema_index] = Some(false); + } else { + let val = min_acc.evaluate()?; + accumulators.is_min_value_exact[logical_schema_index] = + has_any_exact_match(&val, &min_values, exactness); + } } accumulators.null_counts_array[logical_schema_index] = match sum(&null_counts) { @@ -582,6 +600,15 @@ fn has_any_exact_match( array: &ArrayRef, exactness: &BooleanArray, ) -> Option { + if value.is_null() { + return Some(false); + } + + // Shortcut for single row group + if array.len() == 1 { + return Some(exactness.is_valid(0) && exactness.value(0)); + } + let scalar_array = value.to_scalar().ok()?; let eq_mask = eq(&scalar_array, &array).ok()?; let combined_mask = and(&eq_mask, exactness).ok()?; diff --git a/datafusion/datasource-parquet/src/metrics.rs b/datafusion/datasource-parquet/src/metrics.rs index 8ce3a081a2e32..2d6fb69270bf3 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -16,7 +16,7 @@ // under the License. use datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, PruningMetrics, + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, MetricType, PruningMetrics, RatioMergeStrategy, RatioMetrics, Time, }; @@ -45,9 +45,11 @@ pub struct ParquetFileMetrics { pub files_ranges_pruned_statistics: PruningMetrics, /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, - /// Number of row groups whose bloom filters were checked, tracked with matched/pruned counts + /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: PruningMetrics, - /// Number of row groups whose statistics were checked, tracked with matched/pruned counts + /// Number of row groups pruned due to limit pruning. 
+ pub limit_pruned_row_groups: PruningMetrics, + /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: PruningMetrics, /// Total number of bytes scanned pub bytes_scanned: Count, @@ -63,6 +65,8 @@ pub struct ParquetFileMetrics { pub bloom_filter_eval_time: Time, /// Total rows filtered or matched by parquet page index pub page_index_rows_pruned: PruningMetrics, + /// Total pages filtered or matched by parquet page index + pub page_index_pages_pruned: PruningMetrics, /// Total time spent evaluating parquet page index filters pub page_index_eval_time: Time, /// Total time spent reading and parsing metadata from the footer @@ -77,11 +81,16 @@ pub struct ParquetFileMetrics { /// Parquet. /// /// This is the expensive path (IO + Decompression + Decoding). - pub predicate_cache_inner_records: Count, + /// + /// We use a Gauge here as arrow-rs reports absolute numbers rather + /// than incremental readings, we want a `set` operation here rather + /// than `add`. Earlier it was `Count`, which led to this issue: + /// github.com/apache/datafusion/issues/19334 + pub predicate_cache_inner_records: Gauge, /// Predicate Cache: number of records read from the cache. This is the /// number of rows that were stored in the cache after evaluating predicates /// reused for the output. 
- pub predicate_cache_records: Count, + pub predicate_cache_records: Gauge, } impl ParquetFileMetrics { @@ -99,15 +108,20 @@ impl ParquetFileMetrics { .with_type(MetricType::SUMMARY) .pruning_metrics("row_groups_pruned_bloom_filter", partition); + let limit_pruned_row_groups = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::SUMMARY) + .pruning_metrics("limit_pruned_row_groups", partition); + let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .with_type(MetricType::SUMMARY) .pruning_metrics("row_groups_pruned_statistics", partition); - let page_index_rows_pruned = MetricBuilder::new(metrics) + let page_index_pages_pruned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .with_type(MetricType::SUMMARY) - .pruning_metrics("page_index_rows_pruned", partition); + .pruning_metrics("page_index_pages_pruned", partition); let bytes_scanned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) @@ -160,24 +174,30 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .subset_time("page_index_eval_time", partition); + let page_index_rows_pruned = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .pruning_metrics("page_index_rows_pruned", partition); + let predicate_cache_inner_records = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_inner_records", partition); + .gauge("predicate_cache_inner_records", partition); let predicate_cache_records = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_records", partition); + .gauge("predicate_cache_records", partition); Self { files_ranges_pruned_statistics, predicate_evaluation_errors, row_groups_pruned_bloom_filter, row_groups_pruned_statistics, + limit_pruned_row_groups, bytes_scanned, 
pushdown_rows_pruned, pushdown_rows_matched, row_pushdown_eval_time, page_index_rows_pruned, + page_index_pages_pruned, statistics_eval_time, bloom_filter_eval_time, page_index_eval_time, diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index d7e92f70afa99..0e137a706fad7 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -19,7 +19,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod access_plan; pub mod file_format; diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 570f9b4412840..f87a30265a17b 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -47,7 +47,7 @@ use datafusion_physical_expr_common::physical_expr::{ PhysicalExpr, is_dynamic_physical_expr, }; use datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, PruningMetrics, + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, PruningMetrics, }; use datafusion_pruning::{FilePruner, PruningPredicate, build_pruning_predicate}; @@ -69,13 +69,15 @@ use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader, RowGroupMe /// Implements [`FileOpener`] for a parquet file pub(super) struct ParquetOpener { /// Execution partition index - pub partition_index: usize, + pub(crate) partition_index: usize, /// Projection to apply on top of the table schema (i.e. can reference partition columns). 
pub projection: ProjectionExprs, /// Target number of rows in each output RecordBatch pub batch_size: usize, /// Optional limit on the number of rows to read - pub limit: Option, + pub(crate) limit: Option, + /// If should keep the output rows in order + pub preserve_order: bool, /// Optional predicate to apply during the scan pub predicate: Option>, /// Table schema, including partition columns. @@ -277,6 +279,8 @@ impl FileOpener for ParquetOpener { let max_predicate_cache_size = self.max_predicate_cache_size; let reverse_row_groups = self.reverse_row_groups; + let preserve_order = self.preserve_order; + Ok(Box::pin(async move { #[cfg(feature = "parquet_encryption")] let file_decryption_properties = encryption_context @@ -342,7 +346,7 @@ impl FileOpener for ParquetOpener { // Don't load the page index yet. Since it is not stored inline in // the footer, loading the page index if it is not needed will do // unnecessary I/O. We decide later if it is needed to evaluate the - // pruning predicates. Thus default to not requesting if from the + // pruning predicates. Thus default to not requesting it from the // underlying reader. 
let mut options = ArrowReaderOptions::new().with_page_index(false); #[cfg(feature = "parquet_encryption")] @@ -408,7 +412,7 @@ impl FileOpener for ParquetOpener { let rewriter = expr_adapter_factory.create( Arc::clone(&logical_file_schema), Arc::clone(&physical_file_schema), - ); + )?; let simplifier = PhysicalExprSimplifier::new(&physical_file_schema); predicate = predicate .map(|p| simplifier.simplify(rewriter.rewrite(p)?)) @@ -432,7 +436,7 @@ impl FileOpener for ParquetOpener { reader_metadata, &mut async_file_reader, // Since we're manually loading the page index the option here should not matter but we pass it in for consistency - options.with_page_index(true), + options.with_page_index_policy(PageIndexPolicy::Optional), ) .await?; } @@ -545,11 +549,15 @@ impl FileOpener for ParquetOpener { .add_matched(n_remaining_row_groups); } - let mut access_plan = row_groups.build(); + // Prune by limit if limit is set and limit order is not sensitive + if let (Some(limit), false) = (limit, preserve_order) { + row_groups.prune_by_limit(limit, rg_metadata, &file_metrics); + } // -------------------------------------------------------- // Step: prune pages from the kept row groups // + let mut access_plan = row_groups.build(); // page index pruning: if all data on individual pages can // be ruled using page metadata, rows from other columns // with that range can be skipped as well @@ -573,7 +581,7 @@ impl FileOpener for ParquetOpener { // ---------------------------------------------------------- // Step: potentially reverse the access plan for performance. - // See `ParquetSource::try_reverse_output` for the rationale. + // See `ParquetSource::try_pushdown_sort` for the rationale. 
// ---------------------------------------------------------- if reverse_row_groups { prepared_plan = prepared_plan.reverse(file_metadata.as_ref())?; @@ -674,15 +682,15 @@ impl FileOpener for ParquetOpener { /// arrow-rs parquet reader) to the parquet file metrics for DataFusion fn copy_arrow_reader_metrics( arrow_reader_metrics: &ArrowReaderMetrics, - predicate_cache_inner_records: &Count, - predicate_cache_records: &Count, + predicate_cache_inner_records: &Gauge, + predicate_cache_records: &Gauge, ) { if let Some(v) = arrow_reader_metrics.records_read_from_inner() { - predicate_cache_inner_records.add(v); + predicate_cache_inner_records.set(v); } if let Some(v) = arrow_reader_metrics.records_read_from_cache() { - predicate_cache_records.add(v); + predicate_cache_records.set(v); } } @@ -731,6 +739,10 @@ fn constant_value_from_stats( && !min.is_null() && matches!(column_stats.null_count, Precision::Exact(0)) { + // Cast to the expected data type if needed (e.g., Utf8 -> Dictionary) + if min.data_type() != *data_type { + return min.cast_to(data_type).ok(); + } return Some(min.clone()); } @@ -1051,6 +1063,7 @@ mod test { coerce_int96: Option, max_predicate_cache_size: Option, reverse_row_groups: bool, + preserve_order: bool, } impl ParquetOpenerBuilder { @@ -1076,6 +1089,7 @@ mod test { coerce_int96: None, max_predicate_cache_size: None, reverse_row_groups: false, + preserve_order: false, } } @@ -1183,6 +1197,7 @@ mod test { encryption_factory: None, max_predicate_cache_size: self.max_predicate_cache_size, reverse_row_groups: self.reverse_row_groups, + preserve_order: self.preserve_order, } } } diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index e25e33835f790..194e6e94fba3a 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -189,6 +189,10 @@ impl PagePruningAccessPlanFilter { let mut total_skip = 0; // track the total number of rows 
that should not be skipped let mut total_select = 0; + // track the total number of pages that should be skipped + let mut total_pages_skip = 0; + // track the total number of pages that should not be skipped + let mut total_pages_select = 0; // for each row group specified in the access plan let row_group_indexes = access_plan.row_group_indexes(); @@ -226,10 +230,12 @@ impl PagePruningAccessPlanFilter { file_metrics, ); - let Some(selection) = selection else { + let Some((selection, total_pages, matched_pages)) = selection else { trace!("No pages pruned in prune_pages_in_one_row_group"); continue; }; + total_pages_select += matched_pages; + total_pages_skip += total_pages - matched_pages; debug!( "Use filter and page index to create RowSelection {:?} from predicate: {:?}", @@ -278,6 +284,12 @@ impl PagePruningAccessPlanFilter { file_metrics .page_index_rows_pruned .add_matched(total_select); + file_metrics + .page_index_pages_pruned + .add_pruned(total_pages_skip); + file_metrics + .page_index_pages_pruned + .add_matched(total_pages_select); access_plan } @@ -297,7 +309,8 @@ fn update_selection( } } -/// Returns a [`RowSelection`] for the rows in this row group to scan. +/// Returns a [`RowSelection`] for the rows in this row group to scan, in addition to the number of +/// total and matched pages. /// /// This Row Selection is formed from the page index and the predicate skips row /// ranges that can be ruled out based on the predicate. 
@@ -310,7 +323,7 @@ fn prune_pages_in_one_row_group( converter: StatisticsConverter<'_>, parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, -) -> Option { +) -> Option<(RowSelection, usize, usize)> { let pruning_stats = PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; @@ -362,7 +375,11 @@ fn prune_pages_in_one_row_group( RowSelector::skip(sum_row) }; vec.push(selector); - Some(RowSelection::from(vec)) + + let total_pages = values.len(); + let matched_pages = values.iter().filter(|v| **v).count(); + + Some((RowSelection::from(vec), total_pages, matched_pages)) } /// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 04c11b8875541..2924208c5bd99 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -276,7 +276,7 @@ struct PushdownChecker<'schema> { /// Does the expression reference any columns not present in the file schema? projected_columns: bool, /// Indices into the file schema of columns required to evaluate the expression. - required_columns: BTreeSet, + required_columns: Vec, /// Tracks the nested column behavior found during traversal. nested_behavior: NestedColumnSupport, /// Whether nested list columns are supported by the predicate semantics. 
@@ -290,7 +290,7 @@ impl<'schema> PushdownChecker<'schema> { Self { non_primitive_columns: false, projected_columns: false, - required_columns: BTreeSet::default(), + required_columns: Vec::new(), nested_behavior: NestedColumnSupport::PrimitiveOnly, allow_list_columns, file_schema, @@ -307,7 +307,8 @@ impl<'schema> PushdownChecker<'schema> { } }; - self.required_columns.insert(idx); + // Duplicates are handled by dedup() in into_sorted_columns() + self.required_columns.push(idx); let data_type = self.file_schema.field(idx).data_type(); if DataType::is_nested(data_type) { @@ -355,6 +356,21 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } + + /// Consumes the checker and returns sorted, deduplicated column indices + /// wrapped in a `PushdownColumns` struct. + /// + /// This method sorts the column indices and removes duplicates. The sort + /// is required because downstream code relies on column indices being in + /// ascending order for correct schema projection. + fn into_sorted_columns(mut self) -> PushdownColumns { + self.required_columns.sort_unstable(); + self.required_columns.dedup(); + PushdownColumns { + required_columns: self.required_columns, + nested: self.nested_behavior, + } + } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { @@ -390,9 +406,13 @@ enum NestedColumnSupport { Unsupported, } +/// Result of checking which columns are required for filter pushdown. #[derive(Debug)] struct PushdownColumns { - required_columns: BTreeSet, + /// Sorted, unique column indices into the file schema required to evaluate + /// the filter expression. Must be in ascending order for correct schema + /// projection matching. 
+ required_columns: Vec, nested: NestedColumnSupport, } @@ -411,10 +431,7 @@ fn pushdown_columns( let allow_list_columns = supports_list_predicates(expr); let mut checker = PushdownChecker::new(file_schema, allow_list_columns); expr.visit(&mut checker)?; - Ok((!checker.prevents_pushdown()).then_some(PushdownColumns { - required_columns: checker.required_columns, - nested: checker.nested_behavior, - })) + Ok((!checker.prevents_pushdown()).then(|| checker.into_sorted_columns())) } fn leaf_indices_for_roots( @@ -722,6 +739,7 @@ mod test { let expr = logical2physical(&expr, &table_schema); let expr = DefaultPhysicalExprAdapterFactory {} .create(Arc::new(table_schema.clone()), Arc::clone(&file_schema)) + .expect("creating expr adapter") .rewrite(expr) .expect("rewriting expression"); let candidate = FilterCandidateBuilder::new(expr, file_schema.clone()) @@ -761,6 +779,7 @@ mod test { // Rewrite the expression to add CastExpr for type coercion let expr = DefaultPhysicalExprAdapterFactory {} .create(Arc::new(table_schema), Arc::clone(&file_schema)) + .expect("creating expr adapter") .rewrite(expr) .expect("rewriting expression"); let candidate = FilterCandidateBuilder::new(expr, file_schema) diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 046379cc25e23..7eea8285ad6b5 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -24,6 +24,8 @@ use arrow::datatypes::Schema; use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; +use datafusion_physical_expr::PhysicalExprSimplifier; +use datafusion_physical_expr::expressions::NotExpr; use datafusion_pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::parquet_column; @@ -46,13 +48,20 @@ use parquet::{ pub struct 
RowGroupAccessPlanFilter { /// which row groups should be accessed access_plan: ParquetAccessPlan, + /// Row groups where ALL rows are known to match the pruning predicate + /// (the predicate does not filter any rows) + is_fully_matched: Vec, } impl RowGroupAccessPlanFilter { /// Create a new `RowGroupPlanBuilder` for pruning out the groups to scan /// based on metadata and statistics pub fn new(access_plan: ParquetAccessPlan) -> Self { - Self { access_plan } + let num_row_groups = access_plan.len(); + Self { + access_plan, + is_fully_matched: vec![false; num_row_groups], + } } /// Return true if there are no row groups @@ -70,6 +79,139 @@ impl RowGroupAccessPlanFilter { self.access_plan } + /// Returns the is_fully_matched vector + pub fn is_fully_matched(&self) -> &Vec { + &self.is_fully_matched + } + + /// Prunes the access plan based on the limit and fully contained row groups. + /// + /// The pruning works by leveraging the concept of fully matched row groups. Consider a query like: + /// `WHERE species LIKE 'Alpine%' AND s >= 50 LIMIT N` + /// + /// After initial filtering, row groups can be classified into three states: + /// + /// 1. Not Matching / Pruned + /// 2. Partially Matching (Row Group/Page contains some matches) + /// 3. 
Fully Matching (Entire range is within predicate) + /// + /// +-----------------------------------------------------------------------+ + /// | NOT MATCHING | + /// | Row group 1 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Snow Vole | 7 | | + /// | | Brown Bear | 133 ✅ | | + /// | | Gray Wolf | 82 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// +---------------------------------------------------------------------------+ + /// | PARTIALLY MATCHING | + /// | | + /// | Row group 2 Row group 4 | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | SPECIES | S | | SPECIES | S | | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | Lynx | 71 ✅ | | Europ. 
Mole | 4 | | + /// | | Red Fox | 40 | | Polecat | 16 | | + /// | | Alpine Bat ✅ | 6 | | Alpine Ibex ✅ | 97 ✅ | | + /// | +------------------+--------------+ +------------------+----------+ | + /// +---------------------------------------------------------------------------+ + /// + /// +-----------------------------------------------------------------------+ + /// | FULLY MATCHING | + /// | Row group 3 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Alpine Ibex ✅ | 101 ✅ | | + /// | | Alpine Goat ✅ | 76 ✅ | | + /// | | Alpine Sheep ✅ | 83 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// ### Identification of Fully Matching Row Groups + /// + /// DataFusion identifies row groups where ALL rows satisfy the filter by inverting the + /// predicate and checking if statistics prove the inverted version is false for the group. + /// + /// For example, prefix matches like `species LIKE 'Alpine%'` are pruned using ranges: + /// 1. Candidate Range: `species >= 'Alpine' AND species < 'Alpinf'` + /// 2. Inverted Condition (to prove full match): `species < 'Alpine' OR species >= 'Alpinf'` + /// 3. Statistical Evaluation (check if any row *could* satisfy the inverted condition): + /// `min < 'Alpine' OR max >= 'Alpinf'` + /// + /// If this evaluation is **false**, it proves no row can fail the original filter, + /// so the row group is **FULLY MATCHING**. + /// + /// ### Impact of Statistics Truncation + /// + /// The precision of pruning depends on the metadata quality. Truncated statistics + /// may prevent the system from proving a full match. 
+ /// + /// **Example**: `WHERE species LIKE 'Alpine%'` (Target range: `['Alpine', 'Alpinf')`) + /// + /// | Truncation Length | min / max | Inverted Evaluation | Status | + /// |-------------------|---------------------|---------------------------------------------------------------------|------------------------| + /// | **Length 6** | `Alpine` / `Alpine` | `"Alpine" < "Alpine" (F) OR "Alpine" >= "Alpinf" (F)` -> **false** | **FULLY MATCHING** | + /// | **Length 3** | `Alp` / `Alq` | `"Alp" < "Alpine" (T) OR "Alq" >= "Alpinf" (T)` -> **true** | **PARTIALLY MATCHING** | + /// + /// Even though Row Group 3 only contains matching rows, truncation to length 3 makes + /// the statistics `[Alp, Alq]` too broad to prove it (they could include "Alpha"). + /// The system must conservatively scan the group. + /// + /// Without limit pruning: Scan Partition 2 → Partition 3 → Partition 4 (until limit reached) + /// With limit pruning: If Partition 3 contains enough rows to satisfy the limit, + /// skip Partitions 2 and 4 entirely and go directly to Partition 3. 
+ /// + /// This optimization is particularly effective when: + /// - The limit is small relative to the total dataset size + /// - There are row groups that are fully matched by the filter predicates + /// - The fully matched row groups contain sufficient rows to satisfy the limit + /// + /// For more information, see the [paper](https://arxiv.org/pdf/2504.11540)'s "Pruning for LIMIT Queries" part + pub fn prune_by_limit( + &mut self, + limit: usize, + rg_metadata: &[RowGroupMetaData], + metrics: &ParquetFileMetrics, + ) { + let mut fully_matched_row_group_indexes: Vec = Vec::new(); + let mut fully_matched_rows_count: usize = 0; + + // Iterate through the currently accessible row groups and try to + // find a set of matching row groups that can satisfy the limit + for &idx in self.access_plan.row_group_indexes().iter() { + if self.is_fully_matched[idx] { + let row_group_row_count = rg_metadata[idx].num_rows() as usize; + fully_matched_row_group_indexes.push(idx); + fully_matched_rows_count += row_group_row_count; + if fully_matched_rows_count >= limit { + break; + } + } + } + + // If we can satisfy the limit with fully matching row groups, + // rewrite the plan to do so + if fully_matched_rows_count >= limit { + let original_num_accessible_row_groups = + self.access_plan.row_group_indexes().len(); + let new_num_accessible_row_groups = fully_matched_row_group_indexes.len(); + let pruned_count = original_num_accessible_row_groups + .saturating_sub(new_num_accessible_row_groups); + metrics.limit_pruned_row_groups.add_pruned(pruned_count); + + let mut new_access_plan = ParquetAccessPlan::new_none(rg_metadata.len()); + for &idx in &fully_matched_row_group_indexes { + new_access_plan.scan(idx); + } + self.access_plan = new_access_plan; + } + } + /// Prune remaining row groups to only those within the specified range. 
/// /// Updates this set to mark row groups that should not be scanned @@ -135,15 +277,26 @@ impl RowGroupAccessPlanFilter { // try to prune the row groups in a single call match predicate.prune(&pruning_stats) { Ok(values) => { - // values[i] is false means the predicate could not be true for row group i + let mut fully_contained_candidates_original_idx: Vec = Vec::new(); for (idx, &value) in row_group_indexes.iter().zip(values.iter()) { if !value { self.access_plan.skip(*idx); metrics.row_groups_pruned_statistics.add_pruned(1); } else { metrics.row_groups_pruned_statistics.add_matched(1); + fully_contained_candidates_original_idx.push(*idx); } } + + // Check if any of the matched row groups are fully contained by the predicate + self.identify_fully_matched_row_groups( + &fully_contained_candidates_original_idx, + arrow_schema, + parquet_schema, + groups, + predicate, + metrics, + ); } // stats filter array could not be built, so we can't prune Err(e) => { @@ -153,6 +306,68 @@ impl RowGroupAccessPlanFilter { } } + /// Identifies row groups that are fully matched by the predicate. + /// + /// This optimization checks whether all rows in a row group satisfy the predicate + /// by inverting the predicate and checking if it prunes the row group. If the + /// inverted predicate prunes a row group, it means no rows match the inverted + /// predicate, which implies all rows match the original predicate. + /// + /// Note: This optimization is relatively inexpensive for a limited number of row groups. 
+ fn identify_fully_matched_row_groups( + &mut self, + candidate_row_group_indices: &[usize], + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, + ) { + if candidate_row_group_indices.is_empty() { + return; + } + + // Use NotExpr to create the inverted predicate + let inverted_expr = Arc::new(NotExpr::new(Arc::clone(predicate.orig_expr()))); + + // Simplify the NOT expression (e.g., NOT(c1 = 0) -> c1 != 0) + // before building the pruning predicate + let simplifier = PhysicalExprSimplifier::new(arrow_schema); + let Ok(inverted_expr) = simplifier.simplify(inverted_expr) else { + return; + }; + + let Ok(inverted_predicate) = + PruningPredicate::try_new(inverted_expr, Arc::clone(predicate.schema())) + else { + return; + }; + + let inverted_pruning_stats = RowGroupPruningStatistics { + parquet_schema, + row_group_metadatas: candidate_row_group_indices + .iter() + .map(|&i| &groups[i]) + .collect::>(), + arrow_schema, + }; + + let Ok(inverted_values) = inverted_predicate.prune(&inverted_pruning_stats) + else { + return; + }; + + for (i, &original_row_group_idx) in candidate_row_group_indices.iter().enumerate() + { + // If the inverted predicate *also* prunes this row group (meaning inverted_values[i] is false), + // it implies that *all* rows in this group satisfy the original predicate. + if !inverted_values[i] { + self.is_fully_matched[original_row_group_idx] = true; + metrics.row_groups_pruned_statistics.add_fully_matched(1); + } + } + } + /// Prune remaining row groups using available bloom filters and the /// [`PruningPredicate`]. 
/// diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 2e0919b1447de..75d87a4cd16fc 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -548,6 +548,7 @@ impl FileSource for ParquetSource { .batch_size .expect("Batch size must set before creating ParquetOpener"), limit: base_config.limit, + preserve_order: base_config.preserve_order, predicate: self.predicate.clone(), table_schema: self.table_schema.clone(), metadata_size_hint: self.metadata_size_hint, @@ -756,7 +757,7 @@ impl FileSource for ParquetSource { /// # Returns /// - `Inexact`: Created an optimized source (e.g., reversed scan) that approximates the order /// - `Unsupported`: Cannot optimize for this ordering - fn try_reverse_output( + fn try_pushdown_sort( &self, order: &[PhysicalSortExpr], eq_properties: &EquivalenceProperties, diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index df8b70293df00..81a96777e2d0c 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -36,7 +36,7 @@ default = ["compression"] [dependencies] arrow = { workspace = true } -async-compression = { version = "0.4.37", features = [ +async-compression = { version = "0.4.39", features = [ "bzip2", "gzip", "xz", diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index f5380c27ecc28..a0f82ff7a9b58 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -39,12 +39,19 @@ use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use object_store::ObjectStore; -/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +/// Helper function to convert any type implementing [`FileSource`] to `Arc` pub fn as_file_source(source: T) -> Arc { Arc::new(source) } -/// file format specific behaviors for 
elements in [`DataSource`] +/// File format specific behaviors for [`DataSource`] +/// +/// # Schema information +/// There are two important schemas for a [`FileSource`]: +/// 1. [`Self::table_schema`] -- the schema for the overall table +/// (file data plus partition columns) +/// 2. The logical output schema, comprised of [`Self::table_schema`] with +/// [`Self::projection`] applied /// /// See more details on specific implementations: /// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) @@ -64,24 +71,44 @@ pub trait FileSource: Send + Sync { ) -> Result>; /// Any fn as_any(&self) -> &dyn Any; - /// Returns the table schema for this file source. + + /// Returns the table schema for the overall table (including partition columns, if any) + /// + /// This method returns the unprojected schema: the full schema of the data + /// without [`Self::projection`] applied. /// - /// This always returns the unprojected schema (the full schema of the data). + /// The output schema of this `FileSource` is this TableSchema + /// with [`Self::projection`] applied. + /// + /// Use [`ProjectionExprs::project_schema`] to get the projected schema + /// after applying the projection. fn table_schema(&self) -> &crate::table_schema::TableSchema; + /// Initialize new type with batch size configuration fn with_batch_size(&self, batch_size: usize) -> Arc; - /// Returns the filter expression that will be applied during the file scan. + + /// Returns the filter expression that will be applied *during* the file scan. + /// + /// These expressions are in terms of the unprojected [`Self::table_schema`]. fn filter(&self) -> Option> { None } - /// Return the projection that will be applied to the output stream on top of the table schema. + + /// Return the projection that will be applied to the output stream on top + /// of [`Self::table_schema`]. 
+ /// + /// Note you can use [`ProjectionExprs::project_schema`] on the table + /// schema to get the effective output schema of this source. fn projection(&self) -> Option<&ProjectionExprs> { None } + /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; + /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; + /// Format FileType specific information fn fmt_extra(&self, _t: DisplayFormatType, _f: &mut Formatter) -> fmt::Result { Ok(()) @@ -135,6 +162,19 @@ pub trait FileSource: Send + Sync { } /// Try to push down filters into this FileSource. + /// + /// `filters` must be in terms of the unprojected table schema (file schema + /// plus partition columns), before any projection is applied. + /// + /// Any filters that this FileSource chooses to evaluate itself should be + /// returned as `PushedDown::Yes` in the result, along with a FileSource + /// instance that incorporates those filters. Such filters are logically + /// applied "during" the file scan, meaning they may refer to columns not + /// included in the final output projection. + /// + /// Filters that cannot be pushed down should be marked as `PushedDown::No`, + /// and will be evaluated by an execution plan after the file source. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -189,7 +229,29 @@ pub trait FileSource: Send + Sync { /// * `Inexact` - Created a source optimized for ordering (e.g., reversed row groups) but not perfectly sorted /// * `Unsupported` - Cannot optimize for this ordering /// - /// Default implementation returns `Unsupported`. + /// # Deprecation / migration notes + /// - [`Self::try_reverse_output`] was renamed to this method and deprecated since `53.0.0`. 
+ /// Per DataFusion's deprecation guidelines, it will be removed in `59.0.0` or later + /// (6 major versions or 6 months, whichever is longer). + /// - New implementations should override [`Self::try_pushdown_sort`] directly. + /// - For backwards compatibility, the default implementation of + /// [`Self::try_pushdown_sort`] delegates to the deprecated + /// [`Self::try_reverse_output`] until it is removed. After that point, the + /// default implementation will return [`SortOrderPushdownResult::Unsupported`]. + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + eq_properties: &EquivalenceProperties, + ) -> Result>> { + #[expect(deprecated)] + self.try_reverse_output(order, eq_properties) + } + + /// Deprecated: Renamed to [`Self::try_pushdown_sort`]. + #[deprecated( + since = "53.0.0", + note = "Renamed to try_pushdown_sort. This method was never limited to reversing output. It will be removed in 59.0.0 or later." + )] fn try_reverse_output( &self, _order: &[PhysicalSortExpr], @@ -198,7 +260,7 @@ pub trait FileSource: Send + Sync { Ok(SortOrderPushdownResult::Unsupported) } - /// Try to push down a projection into a this FileSource. + /// Try to push down a projection into this FileSource. /// /// `FileSource` implementations that support projection pushdown should /// override this method and return a new `FileSource` instance with the @@ -232,7 +294,7 @@ pub trait FileSource: Send + Sync { /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. /// See `upgrading.md` for more details. #[deprecated( - since = "52.0.0", + since = "53.0.0", note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." )] #[expect(deprecated)] @@ -250,7 +312,7 @@ pub trait FileSource: Send + Sync { /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. /// See `upgrading.md` for more details. 
#[deprecated( - since = "52.0.0", + since = "53.0.0", note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." )] #[expect(deprecated)] diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index 1f7c37315c47a..c3e5cabce7bc2 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -55,10 +55,21 @@ use datafusion_physical_plan::{ use log::{debug, warn}; use std::{any::Any, fmt::Debug, fmt::Formatter, fmt::Result as FmtResult, sync::Arc}; -/// The base configurations for a [`DataSourceExec`], the a physical plan for -/// any given file format. +/// [`FileScanConfig`] represents scanning data from a group of files /// -/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from a ``FileScanConfig`. +/// `FileScanConfig` is used to create a [`DataSourceExec`], the physical plan +/// for scanning files with a particular file format. +/// +/// The [`FileSource`] (e.g. `ParquetSource`, `CsvSource`, etc.) is responsible +/// for creating the actual execution plan to read the files based on a +/// `FileScanConfig`. Fields in a `FileScanConfig` such as Statistics represent +/// information about the files **before** any projection or filtering is +/// applied in the file source. +/// +/// Use [`FileScanConfigBuilder`] to construct a `FileScanConfig`. +/// +/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from +/// a `FileScanConfig`. /// /// # Example /// ``` @@ -152,7 +163,18 @@ pub struct FileScanConfig { /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. pub limit: Option, - /// All equivalent lexicographical orderings that describe the schema. 
+ /// Whether the scan's limit is order sensitive + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without affecting correctness. + pub preserve_order: bool, + /// All equivalent lexicographical output orderings of this file scan, in terms of + /// [`FileSource::table_schema`]. See [`FileScanConfigBuilder::with_output_ordering`] for more + /// details. + /// + /// [`Self::eq_properties`] uses this information along with projection + /// and filtering information to compute the effective + /// [`EquivalenceProperties`] pub output_ordering: Vec, /// File compression type pub file_compression_type: FileCompressionType, @@ -164,8 +186,11 @@ pub struct FileScanConfig { /// Expression adapter used to adapt filters and projections that are pushed down into the scan /// from the logical schema to the physical schema of the file. pub expr_adapter_factory: Option>, - /// Unprojected statistics for the table (file schema + partition columns). - /// These are projected on-demand via `projected_stats()`. + /// Statistics for the entire table (file schema + partition columns). + /// See [`FileScanConfigBuilder::with_statistics`] for more details. + /// + /// The effective statistics are computed on-demand via + /// [`ProjectionExprs::project_statistics`]. 
/// /// Note that this field is pub(crate) because accessing it directly from outside /// would be incorrect if there are filters being applied, thus this should be accessed @@ -240,6 +265,7 @@ pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, file_source: Arc, limit: Option, + preserve_order: bool, constraints: Option, file_groups: Vec, statistics: Option, @@ -269,6 +295,7 @@ impl FileScanConfigBuilder { output_ordering: vec![], file_compression_type: None, limit: None, + preserve_order: false, constraints: None, batch_size: None, expr_adapter_factory: None, @@ -276,22 +303,35 @@ impl FileScanConfigBuilder { } } - /// Set the maximum number of records to read from this plan. If `None`, - /// all records after filtering are returned. + /// Set the maximum number of records to read from this plan. + /// + /// If `None`, all records after filtering are returned. pub fn with_limit(mut self, limit: Option) -> Self { self.limit = limit; self } + /// Set whether the limit should be order-sensitive. + /// + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without + /// affecting correctness. + pub fn with_preserve_order(mut self, order_sensitive: bool) -> Self { + self.preserve_order = order_sensitive; + self + } + /// Set the file source for scanning files. /// - /// This method allows you to change the file source implementation (e.g. ParquetSource, CsvSource, etc.) - /// after the builder has been created. + /// This method allows you to change the file source implementation (e.g. + /// ParquetSource, CsvSource, etc.) after the builder has been created. 
pub fn with_source(mut self, file_source: Arc) -> Self { self.file_source = file_source; self } + /// Return the table schema pub fn table_schema(&self) -> &SchemaRef { self.file_source.table_schema().table_schema() } @@ -316,7 +356,12 @@ impl FileScanConfigBuilder { /// Set the columns on which to project the data using column indices. /// - /// Indexes that are higher than the number of columns of `file_schema` refer to `table_partition_cols`. + /// This method attempts to push down the projection to the underlying file + /// source if supported. If the file source does not support projection + /// pushdown, an error is returned. + /// + /// Indexes that are higher than the number of columns of `file_schema` + /// refer to `table_partition_cols`. pub fn with_projection_indices( mut self, indices: Option>, @@ -355,8 +400,18 @@ impl FileScanConfigBuilder { self } - /// Set the estimated overall statistics of the files, taking `filters` into account. - /// Defaults to [`Statistics::new_unknown`]. + /// Set the statistics of the files, including partition + /// columns. Defaults to [`Statistics::new_unknown`]. + /// + /// These statistics are for the entire table (file schema + partition + /// columns) before any projection or filtering is applied. Projections are + /// applied when statistics are retrieved, and if a filter is present, + /// [`FileScanConfig::statistics`] will mark the statistics as inexact + /// (counts are not adjusted). + /// + /// Projections and filters may be applied by the file source, either by + /// [`Self::with_projection_indices`] or a preexisting + /// [`FileSource::projection`] or [`FileSource::filter`]. 
pub fn with_statistics(mut self, statistics: Statistics) -> Self { self.statistics = Some(statistics); self @@ -392,6 +447,13 @@ impl FileScanConfigBuilder { } /// Set the output ordering of the files + /// + /// The expressions are in terms of the entire table schema (file schema + + /// partition columns), before any projection or filtering from the file + /// scan is applied. + /// + /// This is used for optimization purposes, e.g. to determine if a file scan + /// can satisfy an `ORDER BY` without an additional sort. pub fn with_output_ordering(mut self, output_ordering: Vec) -> Self { self.output_ordering = output_ordering; self @@ -450,6 +512,7 @@ impl FileScanConfigBuilder { object_store_url, file_source, limit, + preserve_order, constraints, file_groups, statistics, @@ -467,10 +530,14 @@ impl FileScanConfigBuilder { let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); + // If there is an output ordering, we should preserve it. + let preserve_order = preserve_order || !output_ordering.is_empty(); + FileScanConfig { object_store_url, file_source, limit, + preserve_order, constraints, file_groups, output_ordering, @@ -493,6 +560,7 @@ impl From for FileScanConfigBuilder { output_ordering: config.output_ordering, file_compression_type: Some(config.file_compression_type), limit: config.limit, + preserve_order: config.preserve_order, constraints: Some(config.constraints), batch_size: config.batch_size, expr_adapter_factory: config.expr_adapter_factory, @@ -661,11 +729,14 @@ impl DataSource for FileScanConfig { Partitioning::UnknownPartitioning(self.file_groups.len()) } + /// Computes the effective equivalence properties of this file scan, taking + /// into account the file schema, any projections or filters applied by the + /// file source, and the output ordering. 
fn eq_properties(&self) -> EquivalenceProperties { let schema = self.file_source.table_schema().table_schema(); let mut eq_properties = EquivalenceProperties::new_with_orderings( Arc::clone(schema), - self.output_ordering.clone(), + self.validated_output_ordering(), ) .with_constraints(self.constraints.clone()); @@ -771,37 +842,27 @@ impl DataSource for FileScanConfig { config: &ConfigOptions, ) -> Result>> { // Remap filter Column indices to match the table schema (file + partition columns). - // This is necessary because filters may have been created against a different schema - // (e.g., after projection pushdown) and need to be remapped to the table schema - // before being passed to the file source and ultimately serialized. - // For example, the filter being pushed down is `c1_c2 > 5` and it was created - // against the output schema of the this `DataSource` which has projection `c1 + c2 as c1_c2`. - // Thus we need to rewrite the filter back to `c1 + c2 > 5` before passing it to the file source. + // This is necessary because filters refer to the output schema of this `DataSource` + // (e.g., after projection pushdown has been applied) and need to be remapped to the table schema + // before being passed to the file source + // + // For example, consider a filter `c1_c2 > 5` being pushed down. If the + // `DataSource` has a projection `c1 + c2 as c1_c2`, the filter must be rewritten + // to refer to the table schema `c1 + c2 > 5` let table_schema = self.file_source.table_schema().table_schema(); - // If there's a projection with aliases, first map the filters back through - // the projection expressions before remapping to the table schema. 
let filters_to_remap = if let Some(projection) = self.file_source.projection() { - use datafusion_physical_plan::projection::update_expr; filters .into_iter() - .map(|filter| { - update_expr(&filter, projection.as_ref(), true)?.ok_or_else(|| { - internal_datafusion_err!( - "Failed to map filter expression through projection: {}", - filter - ) - }) - }) + .map(|filter| projection.unproject_expr(&filter)) .collect::>>()? } else { filters }; // Now remap column indices to match the table schema. - let remapped_filters: Result> = filters_to_remap + let remapped_filters = filters_to_remap .into_iter() - .map(|filter| reassign_expr_columns(filter, table_schema.as_ref())) - .collect(); - let remapped_filters = remapped_filters?; + .map(|filter| reassign_expr_columns(filter, table_schema)) + .collect::>>()?; let result = self .file_source @@ -829,20 +890,20 @@ impl DataSource for FileScanConfig { &self, order: &[PhysicalSortExpr], ) -> Result>> { - // Delegate to FileSource to check if reverse scanning can satisfy the request. + // Delegate to FileSource to see if it can optimize for the requested ordering. 
let pushdown_result = self .file_source - .try_reverse_output(order, &self.eq_properties())?; + .try_pushdown_sort(order, &self.eq_properties())?; match pushdown_result { SortOrderPushdownResult::Exact { inner } => { Ok(SortOrderPushdownResult::Exact { - inner: self.rebuild_with_source(inner, true)?, + inner: self.rebuild_with_source(inner, true, order)?, }) } SortOrderPushdownResult::Inexact { inner } => { Ok(SortOrderPushdownResult::Inexact { - inner: self.rebuild_with_source(inner, false)?, + inner: self.rebuild_with_source(inner, false, order)?, }) } SortOrderPushdownResult::Unsupported => { @@ -850,9 +911,55 @@ impl DataSource for FileScanConfig { } } } + + fn with_preserve_order(&self, preserve_order: bool) -> Option> { + if self.preserve_order == preserve_order { + return Some(Arc::new(self.clone())); + } + + let new_config = FileScanConfig { + preserve_order, + ..self.clone() + }; + Some(Arc::new(new_config)) + } } impl FileScanConfig { + /// Returns only the output orderings that are validated against actual + /// file group statistics. + /// + /// For example, individual files may be ordered by `col1 ASC`, + /// but if we have files with these min/max statistics in a single partition / file group: + /// + /// - file1: min(col1) = 10, max(col1) = 20 + /// - file2: min(col1) = 5, max(col1) = 15 + /// + /// Because reading file1 followed by file2 would produce out-of-order output (there is overlap + /// in the ranges), we cannot retain `col1 ASC` as a valid output ordering. + /// + /// Similarly this would not be a valid order (non-overlapping ranges but not ordered): + /// + /// - file1: min(col1) = 20, max(col1) = 30 + /// - file2: min(col1) = 10, max(col1) = 15 + /// + /// On the other hand if we had: + /// + /// - file1: min(col1) = 5, max(col1) = 15 + /// - file2: min(col1) = 16, max(col1) = 25 + /// + /// Then we know that reading file1 followed by file2 will produce ordered output, + /// so `col1 ASC` would be retained. 
+ /// + /// Note that we are checking for ordering *within* *each* file group / partition, + /// files in different partitions are read independently and do not affect each other's ordering. + /// Merging of the multiple partition streams into a single ordered stream is handled + /// upstream e.g. by `SortPreservingMergeExec`. + fn validated_output_ordering(&self) -> Vec { + let schema = self.file_source.table_schema().table_schema(); + validate_orderings(&self.output_ordering, schema, &self.file_groups, None) + } + /// Get the file schema (schema of the files without partition columns) pub fn file_schema(&self) -> &SchemaRef { self.file_source.table_schema().file_schema() @@ -1123,19 +1230,44 @@ impl FileScanConfig { &self, new_file_source: Arc, is_exact: bool, + order: &[PhysicalSortExpr], ) -> Result> { let mut new_config = self.clone(); - // Reverse file groups (FileScanConfig's responsibility) - new_config.file_groups = new_config - .file_groups - .into_iter() - .map(|group| { - let mut files = group.into_inner(); - files.reverse(); - files.into() - }) - .collect(); + // Reverse file order (within each group) if the caller is requesting a reversal of this + // scan's declared output ordering. + // + // Historically this function always reversed `file_groups` because it was only reached + // via `FileSource::try_reverse_output` (where a reversal was the only supported + // optimization). + // + // Now that `FileSource::try_pushdown_sort` is generic, we must not assume reversal: other + // optimizations may become possible (e.g. already-sorted data, statistics-based file + // reordering). Therefore we only reverse files when it is known to help satisfy the + // requested ordering. 
+ let reverse_file_groups = if self.output_ordering.is_empty() { + false + } else if let Some(requested) = LexOrdering::new(order.iter().cloned()) { + let projected_schema = self.projected_schema()?; + let orderings = project_orderings(&self.output_ordering, &projected_schema); + orderings + .iter() + .any(|ordering| ordering.is_reverse(&requested)) + } else { + false + }; + + if reverse_file_groups { + new_config.file_groups = new_config + .file_groups + .into_iter() + .map(|group| { + let mut files = group.into_inner(); + files.reverse(); + files.into() + }) + .collect(); + } new_config.file_source = new_file_source; @@ -1202,6 +1334,51 @@ fn ordered_column_indices_from_projection( .collect::>>() } +/// Check whether a given ordering is valid for all file groups by verifying +/// that files within each group are sorted according to their min/max statistics. +/// +/// For single-file (or empty) groups, the ordering is trivially valid. +/// For multi-file groups, we check that the min/max statistics for the sort +/// columns are in order and non-overlapping (or touching at boundaries). +/// +/// `projection` maps projected column indices back to table-schema indices +/// when validating after projection; pass `None` when validating at +/// table-schema level. +fn is_ordering_valid_for_file_groups( + file_groups: &[FileGroup], + ordering: &LexOrdering, + schema: &SchemaRef, + projection: Option<&[usize]>, +) -> bool { + file_groups.iter().all(|group| { + if group.len() <= 1 { + return true; // single-file groups are trivially sorted + } + match MinMaxStatistics::new_from_files(ordering, schema, projection, group.iter()) + { + Ok(stats) => stats.is_sorted(), + Err(_) => false, // can't prove sorted → reject + } + }) +} + +/// Filters orderings to retain only those valid for all file groups, +/// verified via min/max statistics. 
+fn validate_orderings( + orderings: &[LexOrdering], + schema: &SchemaRef, + file_groups: &[FileGroup], + projection: Option<&[usize]>, +) -> Vec { + orderings + .iter() + .filter(|ordering| { + is_ordering_valid_for_file_groups(file_groups, ordering, schema, projection) + }) + .cloned() + .collect() +} + /// The various listing tables does not attempt to read all files /// concurrently, instead they will read files in sequence within a /// partition. This is an important property as it allows plans to @@ -1268,52 +1445,47 @@ fn get_projected_output_ordering( let projected_orderings = project_orderings(&base_config.output_ordering, projected_schema); - let mut all_orderings = vec![]; - for new_ordering in projected_orderings { - // Check if any file groups are not sorted - if base_config.file_groups.iter().any(|group| { - if group.len() <= 1 { - // File groups with <= 1 files are always sorted - return false; - } - - let Some(indices) = base_config - .file_source - .projection() - .as_ref() - .map(|p| ordered_column_indices_from_projection(p)) - else { - // Can't determine if ordered without a simple projection - return true; - }; - - let statistics = match MinMaxStatistics::new_from_files( - &new_ordering, + let indices = base_config + .file_source + .projection() + .as_ref() + .map(|p| ordered_column_indices_from_projection(p)); + + match indices { + Some(Some(indices)) => { + // Simple column projection — validate with statistics + validate_orderings( + &projected_orderings, projected_schema, - indices.as_deref(), - group.iter(), - ) { - Ok(statistics) => statistics, - Err(e) => { - log::trace!("Error fetching statistics for file group: {e}"); - // we can't prove that it's ordered, so we have to reject it - return true; - } - }; - - !statistics.is_sorted() - }) { - debug!( - "Skipping specified output ordering {:?}. 
\ - Some file groups couldn't be determined to be sorted: {:?}", - base_config.output_ordering[0], base_config.file_groups - ); - continue; + &base_config.file_groups, + Some(indices.as_slice()), + ) + } + None => { + // No projection — validate with statistics (no remapping needed) + validate_orderings( + &projected_orderings, + projected_schema, + &base_config.file_groups, + None, + ) + } + Some(None) => { + // Complex projection (expressions, not simple columns) — can't + // determine column indices for statistics. Still valid if all + // file groups have at most one file. + if base_config.file_groups.iter().all(|g| g.len() <= 1) { + projected_orderings + } else { + debug!( + "Skipping specified output orderings. \ + Some file groups couldn't be determined to be sorted: {:?}", + base_config.file_groups + ); + vec![] + } } - - all_orderings.push(new_ordering); } - all_orderings } /// Convert type to a type suitable for use as a `ListingTable` @@ -1358,6 +1530,62 @@ mod tests { use datafusion_physical_expr::projection::ProjectionExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + #[derive(Clone)] + struct InexactSortPushdownSource { + metrics: ExecutionPlanMetricsSet, + table_schema: TableSchema, + } + + impl InexactSortPushdownSource { + fn new(table_schema: TableSchema) -> Self { + Self { + metrics: ExecutionPlanMetricsSet::new(), + table_schema, + } + } + } + + impl FileSource for InexactSortPushdownSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Result> { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_schema(&self) -> &TableSchema { + &self.table_schema + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn file_type(&self) -> &str { + "mock" + } + + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + 
_eq_properties: &EquivalenceProperties, + ) -> Result>> { + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(self.clone()) as Arc, + }) + } + } + #[test] fn physical_plan_config_no_projection_tab_cols_as_field() { let file_schema = aggr_test_schema(); @@ -2303,4 +2531,56 @@ mod tests { _ => panic!("Expected Hash partitioning"), } } + + #[test] + fn try_pushdown_sort_reverses_file_groups_only_when_requested_is_reverse() + -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let table_schema = TableSchema::new(Arc::clone(&file_schema), vec![]); + let file_source = Arc::new(InexactSortPushdownSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + PartitionedFile::new("file1", 1), + PartitionedFile::new("file2", 1), + ])]; + + let sort_expr_asc = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr_asc.clone()]).unwrap(), + ]) + .build(); + + let requested_asc = vec![sort_expr_asc.clone()]; + let result = config.try_pushdown_sort(&requested_asc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .as_any() + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file2"); + + let requested_desc = vec![sort_expr_asc.reverse()]; + let result = config.try_pushdown_sort(&requested_desc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .as_any() + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = 
pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file2"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file1"); + + Ok(()) + } } diff --git a/datafusion/datasource/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs index 643831a1199f8..1abce86a3565f 100644 --- a/datafusion/datasource/src/file_sink_config.rs +++ b/datafusion/datasource/src/file_sink_config.rs @@ -32,6 +32,52 @@ use datafusion_expr::dml::InsertOp; use async_trait::async_trait; use object_store::ObjectStore; +/// Determines how `FileSink` output paths are interpreted. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FileOutputMode { + /// Infer output mode from the output URL (for example, by extension / trailing `/`). + #[default] + Automatic, + /// Write to a single output file at the exact output path. + SingleFile, + /// Write to a directory under the output path with generated filenames. + Directory, +} + +impl FileOutputMode { + /// Resolve this mode into a `single_file_output` boolean for the demuxer. 
+ pub fn single_file_output(self, base_output_path: &ListingTableUrl) -> bool { + match self { + Self::Automatic => { + !base_output_path.is_collection() + && base_output_path.file_extension().is_some() + } + Self::SingleFile => true, + Self::Directory => false, + } + } +} + +impl From> for FileOutputMode { + fn from(value: Option) -> Self { + match value { + None => Self::Automatic, + Some(true) => Self::SingleFile, + Some(false) => Self::Directory, + } + } +} + +impl From for Option { + fn from(value: FileOutputMode) -> Self { + match value { + FileOutputMode::Automatic => None, + FileOutputMode::SingleFile => Some(true), + FileOutputMode::Directory => Some(false), + } + } +} + /// General behaviors for files that do `DataSink` operations #[async_trait] pub trait FileSink: DataSink { @@ -112,6 +158,8 @@ pub struct FileSinkConfig { pub keep_partition_by_columns: bool, /// File extension without a dot(.) pub file_extension: String, + /// Determines how the output path is interpreted. + pub file_output_mode: FileOutputMode, } impl FileSinkConfig { diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 2965be7637899..f80c9cb0b0daa 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! A table that uses the `ObjectStore` listing capability //! to get the list of files to process. diff --git a/datafusion/datasource/src/schema_adapter.rs b/datafusion/datasource/src/schema_adapter.rs index 3d0b06954e085..c995fa58d6c89 100644 --- a/datafusion/datasource/src/schema_adapter.rs +++ b/datafusion/datasource/src/schema_adapter.rs @@ -115,10 +115,20 @@ pub trait SchemaMapper: Debug + Send + Sync { /// Deprecated: Default [`SchemaAdapterFactory`] for mapping schemas. 
/// -/// This struct has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// This struct has been removed. +/// +/// Use [`PhysicalExprAdapterFactory`] instead to customize scans via +/// [`FileScanConfigBuilder`], i.e. if you had implemented a custom [`SchemaAdapter`] +/// and passed that into [`FileScanConfigBuilder`] / [`ParquetSource`]. +/// Use [`BatchAdapter`] if you want to map a stream of [`RecordBatch`]es +/// between one schema and another, i.e. if you were calling [`SchemaMapper::map_batch`] manually. +/// /// See `upgrading.md` for more details. /// /// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +/// [`FileScanConfigBuilder`]: crate::file_scan_config::FileScanConfigBuilder +/// [`ParquetSource`]: https://docs.rs/datafusion-datasource-parquet/latest/datafusion_datasource_parquet/source/struct.ParquetSource.html +/// [`BatchAdapter`]: datafusion_physical_expr_adapter::BatchAdapter #[deprecated( since = "52.0.0", note = "DefaultSchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." @@ -178,10 +188,20 @@ impl SchemaAdapter for DeprecatedSchemaAdapter { /// Deprecated: The SchemaMapping struct held a mapping from the file schema to the table schema. /// -/// This struct has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// This struct has been removed. +/// +/// Use [`PhysicalExprAdapterFactory`] instead to customize scans via +/// [`FileScanConfigBuilder`], i.e. if you had implemented a custom [`SchemaAdapter`] +/// and passed that into [`FileScanConfigBuilder`] / [`ParquetSource`]. +/// Use [`BatchAdapter`] if you want to map a stream of [`RecordBatch`]es +/// between one schema and another, i.e. if you were calling [`SchemaMapper::map_batch`] manually. +/// /// See `upgrading.md` for more details. 
/// /// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +/// [`FileScanConfigBuilder`]: crate::file_scan_config::FileScanConfigBuilder +/// [`ParquetSource`]: https://docs.rs/datafusion-datasource-parquet/latest/datafusion_datasource_parquet/source/struct.ParquetSource.html +/// [`BatchAdapter`]: datafusion_physical_expr_adapter::BatchAdapter #[deprecated( since = "52.0.0", note = "SchemaMapping has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index 5460a0ffdc3df..5acc89722b200 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -94,7 +94,7 @@ pub struct DataSinkExec { impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "DataSinkExec schema: {:?}", self.count_schema) + write!(f, "DataSinkExec schema: {}", self.count_schema) } } diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index a3892dfac9778..a4e27dac769af 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -158,16 +158,6 @@ pub trait DataSource: Send + Sync + Debug { /// across all partitions if `partition` is `None`. fn partition_statistics(&self, partition: Option) -> Result; - /// Returns aggregate statistics across all partitions. - /// - /// # Deprecated - /// Use [`Self::partition_statistics`] instead, which provides more fine-grained - /// control over statistics retrieval (per-partition or aggregate). 
- #[deprecated(since = "51.0.0", note = "Use partition_statistics instead")] - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -178,7 +168,13 @@ pub trait DataSource: Send + Sync + Debug { &self, _projection: &ProjectionExprs, ) -> Result>>; + /// Try to push down filters into this DataSource. + /// + /// These filters are in terms of the output schema of this DataSource (e.g. + /// [`Self::eq_properties`] and output of any projections pushed into the + /// source), not the original table schema. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -210,6 +206,11 @@ pub trait DataSource: Send + Sync + Debug { ) -> Result>> { Ok(SortOrderPushdownResult::Unsupported) } + + /// Returns a variant of this `DataSource` that is aware of order-sensitivity. 
+ fn with_preserve_order(&self, _preserve_order: bool) -> Option> { + None + } } /// [`ExecutionPlan`] that reads one or more files @@ -393,6 +394,18 @@ impl ExecutionPlan for DataSourceExec { Ok(Arc::new(new_exec) as Arc) }) } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.data_source + .with_preserve_order(preserve_order) + .map(|new_data_source| { + Arc::new(self.clone().with_data_source(new_data_source)) + as Arc + }) + } } impl DataSourceExec { diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index 2f34ca032e132..b1a56e096c222 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -266,11 +266,12 @@ impl MinMaxStatistics { } /// Check if the min/max statistics are in order and non-overlapping + /// (or touching at boundaries) pub fn is_sorted(&self) -> bool { self.max_by_sort_order .iter() .zip(self.min_by_sort_order.iter().skip(1)) - .all(|(max, next_min)| max < next_min) + .all(|(max, next_min)| max <= next_min) } } diff --git a/datafusion/datasource/src/table_schema.rs b/datafusion/datasource/src/table_schema.rs index a45cdbaaea076..5b7fc4727df05 100644 --- a/datafusion/datasource/src/table_schema.rs +++ b/datafusion/datasource/src/table_schema.rs @@ -20,13 +20,13 @@ use arrow::datatypes::{FieldRef, SchemaBuilder, SchemaRef}; use std::sync::Arc; -/// Helper to hold table schema information for partitioned data sources. +/// The overall schema for potentially partitioned data sources. /// -/// When reading partitioned data (such as Hive-style partitioning), a table's schema +/// When reading partitioned data (such as Hive-style partitioning), a [`TableSchema`] /// consists of two parts: /// 1. **File schema**: The schema of the actual data files on disk -/// 2. **Partition columns**: Columns that are encoded in the directory structure, -/// not stored in the files themselves +/// 2. 
**Partition columns**: Columns whose values are encoded in the directory structure, +/// but not stored in the files themselves /// /// # Example: Partitioned Table /// diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 678bd280fc97e..0c274806c09c3 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -43,7 +43,7 @@ pub struct ListingTableUrl { prefix: Path, /// An optional glob expression used to filter files glob: Option, - + /// Optional table reference for the table this url belongs to table_ref: Option, } @@ -341,17 +341,19 @@ impl ListingTableUrl { } /// Returns a copy of current [`ListingTableUrl`] with a specified `glob` - pub fn with_glob(self, glob: &str) -> Result { - let glob = - Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?; - Self::try_new(self.url, Some(glob)) + pub fn with_glob(mut self, glob: &str) -> Result { + self.glob = + Some(Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?); + Ok(self) } + /// Set the table reference for this [`ListingTableUrl`] pub fn with_table_ref(mut self, table_ref: TableReference) -> Self { self.table_ref = Some(table_ref); self } + /// Return the table reference for this [`ListingTableUrl`] pub fn get_table_ref(&self) -> &Option { &self.table_ref } diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index bec5b8b0bff0e..1648624747af2 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -35,8 +35,8 @@ use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::{ as_boolean_array, as_date32_array, as_date64_array, as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, - as_int64_array, as_string_array, as_string_view_array, as_uint8_array, - as_uint16_array, as_uint32_array, as_uint64_array, + as_int64_array, as_large_string_array, as_string_array, 
as_string_view_array, + as_uint8_array, as_uint16_array, as_uint32_array, as_uint64_array, }; use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err}; use datafusion_common_runtime::SpawnedTask; @@ -106,8 +106,9 @@ pub(crate) fn start_demuxer_task( let file_extension = config.file_extension.clone(); let base_output_path = config.table_paths[0].clone(); let task = if config.table_partition_cols.is_empty() { - let single_file_output = !base_output_path.is_collection() - && base_output_path.file_extension().is_some(); + let single_file_output = config + .file_output_mode + .single_file_output(&base_output_path); SpawnedTask::spawn(async move { row_count_demuxer( tx, @@ -397,6 +398,12 @@ fn compute_partition_keys_by_row<'a>( partition_values.push(Cow::from(array.value(i))); } } + DataType::LargeUtf8 => { + let array = as_large_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i))); + } + } DataType::Utf8View => { let array = as_string_view_array(col_array)?; for i in 0..rb.num_rows() { diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 836cb9345b51f..591a5a62f3b20 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs index 4cc5586440230..bd34c441bdbde 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -196,6 +196,7 @@ pub trait ListFilesCache: CacheAccessor { /// Retrieves the information about the entries currently cached. 
fn list_entries(&self) -> HashMap; + /// Drop all entries for the given table reference. fn drop_table_entries(&self, table_ref: &Option) -> Result<()>; } diff --git a/datafusion/execution/src/cache/list_files_cache.rs b/datafusion/execution/src/cache/list_files_cache.rs index c86a03574e3a3..b1b8e6b500169 100644 --- a/datafusion/execution/src/cache/list_files_cache.rs +++ b/datafusion/execution/src/cache/list_files_cache.rs @@ -139,6 +139,11 @@ pub const DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT: usize = 1024 * 1024; // 1MiB /// The default cache TTL for the [`DefaultListFilesCache`] pub const DEFAULT_LIST_FILES_CACHE_TTL: Option = None; // Infinite +/// Key for [`DefaultListFilesCache`] +/// +/// Each entry is scoped to its use within a specific table so that the cache +/// can differentiate between identical paths in different tables, and +/// table-level cache invalidation. #[derive(PartialEq, Eq, Hash, Clone, Debug)] pub struct TableScopedPath { pub table: Option, diff --git a/datafusion/execution/src/cache/mod.rs b/datafusion/execution/src/cache/mod.rs index 417cb86cd9e6c..0380e50c0935c 100644 --- a/datafusion/execution/src/cache/mod.rs +++ b/datafusion/execution/src/cache/mod.rs @@ -24,6 +24,7 @@ mod list_files_cache; pub use file_metadata_cache::DefaultFilesMetadataCache; pub use list_files_cache::DefaultListFilesCache; +pub use list_files_cache::ListFilesEntry; pub use list_files_cache::TableScopedPath; /// Base trait for cache implementations with common operations. 
diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 30ba7de76a471..854d239236766 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -480,6 +480,12 @@ impl SessionConfig { self.options.execution.enforce_batch_size_in_joins } + /// Toggle SQL ANSI mode for expressions, casting, and error handling + pub fn with_enable_ansi_mode(mut self, enable_ansi_mode: bool) -> Self { + self.options_mut().execution.enable_ansi_mode = enable_ansi_mode; + self + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index cb87053d8d035..d878fdcf66a4c 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -25,7 +25,7 @@ use parking_lot::Mutex; use rand::{Rng, rng}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use tempfile::{Builder, NamedTempFile, TempDir}; use datafusion_common::human_readable_size; @@ -77,6 +77,7 @@ impl DiskManagerBuilder { local_dirs: Mutex::new(Some(vec![])), max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }), DiskManagerMode::Directories(conf_dirs) => { let local_dirs = create_local_dirs(&conf_dirs)?; @@ -87,12 +88,14 @@ impl DiskManagerBuilder { local_dirs: Mutex::new(Some(local_dirs)), max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }) } DiskManagerMode::Disabled => Ok(DiskManager { local_dirs: Mutex::new(None), max_temp_directory_size: self.max_temp_directory_size, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), }), } } @@ 
-169,6 +172,17 @@ pub struct DiskManager { /// Used disk space in the temporary directories. Now only spilled data for /// external executors are counted. used_disk_space: Arc, + /// Number of active temporary files created by this disk manager + active_files_count: Arc, +} + +/// Information about the current disk usage for spilling +#[derive(Debug, Clone, Copy)] +pub struct SpillingProgress { + /// Total bytes currently used on disk for spilling + pub current_bytes: u64, + /// Total number of active spill files + pub active_files_count: usize, } impl DiskManager { @@ -187,6 +201,7 @@ impl DiskManager { local_dirs: Mutex::new(Some(vec![])), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })), DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(&conf_dirs)?; @@ -197,12 +212,14 @@ impl DiskManager { local_dirs: Mutex::new(Some(local_dirs)), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })) } DiskManagerConfig::Disabled => Ok(Arc::new(Self { local_dirs: Mutex::new(None), max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, used_disk_space: Arc::new(AtomicU64::new(0)), + active_files_count: Arc::new(AtomicUsize::new(0)), })), } } @@ -252,6 +269,14 @@ impl DiskManager { self.max_temp_directory_size } + /// Returns the current spilling progress + pub fn spilling_progress(&self) -> SpillingProgress { + SpillingProgress { + current_bytes: self.used_disk_space.load(Ordering::Relaxed), + active_files_count: self.active_files_count.load(Ordering::Relaxed), + } + } + /// Returns the temporary directory paths pub fn temp_dir_paths(&self) -> Vec { self.local_dirs @@ -301,6 +326,7 @@ impl DiskManager { } let dir_index = rng().random_range(0..local_dirs.len()); + self.active_files_count.fetch_add(1, Ordering::Relaxed); 
Ok(RefCountedTempFile { parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Arc::new( @@ -422,6 +448,9 @@ impl Drop for RefCountedTempFile { self.disk_manager .used_disk_space .fetch_sub(current_usage, Ordering::Relaxed); + self.disk_manager + .active_files_count + .fetch_sub(1, Ordering::Relaxed); } } } diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index aced2f46d7224..1a8da9459ae10 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! DataFusion execution configuration and runtime structures diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index fbf9ce41da8fe..30ad658d0d390 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,7 +18,7 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy`] for //! help with allocation accounting. 
-use datafusion_common::{Result, internal_err}; +use datafusion_common::{Result, internal_datafusion_err}; use std::hash::{Hash, Hasher}; use std::{cmp::Ordering, sync::Arc, sync::atomic}; @@ -322,7 +322,7 @@ impl MemoryConsumer { pool: Arc::clone(pool), consumer: self, }), - size: 0, + size: atomic::AtomicUsize::new(0), } } } @@ -351,13 +351,13 @@ impl Drop for SharedRegistration { #[derive(Debug)] pub struct MemoryReservation { registration: Arc, - size: usize, + size: atomic::AtomicUsize, } impl MemoryReservation { /// Returns the size of this reservation in bytes pub fn size(&self) -> usize { - self.size + self.size.load(atomic::Ordering::Relaxed) } /// Returns [MemoryConsumer] for this [MemoryReservation] @@ -367,10 +367,10 @@ impl MemoryReservation { /// Frees all bytes from this reservation back to the underlying /// pool, returning the number of bytes freed. - pub fn free(&mut self) -> usize { - let size = self.size; + pub fn free(&self) -> usize { + let size = self.size.swap(0, atomic::Ordering::Relaxed); if size != 0 { - self.shrink(size) + self.registration.pool.shrink(self, size); } size } @@ -380,60 +380,73 @@ impl MemoryReservation { /// # Panics /// /// Panics if `capacity` exceeds [`Self::size`] - pub fn shrink(&mut self, capacity: usize) { - let new_size = self.size.checked_sub(capacity).unwrap(); + pub fn shrink(&self, capacity: usize) { + self.size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), + ) + .expect("capacity exceeds reservation size"); self.registration.pool.shrink(self, capacity); - self.size = new_size } /// Tries to free `capacity` bytes from this reservation - /// if `capacity` does not exceed [`Self::size`] - /// Returns new reservation size - /// or error if shrinking capacity is more than allocated size - pub fn try_shrink(&mut self, capacity: usize) -> Result { - if let Some(new_size) = self.size.checked_sub(capacity) { - self.registration.pool.shrink(self, 
capacity); - self.size = new_size; - Ok(new_size) - } else { - internal_err!( - "Cannot free the capacity {capacity} out of allocated size {}", - self.size + /// if `capacity` does not exceed [`Self::size`]. + /// Returns new reservation size, + /// or error if shrinking capacity is more than allocated size. + pub fn try_shrink(&self, capacity: usize) -> Result { + let prev = self + .size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), ) - } + .map_err(|_| { + let prev = self.size.load(atomic::Ordering::Relaxed); + internal_datafusion_err!( + "Cannot free the capacity {capacity} out of allocated size {prev}" + ) + })?; + + self.registration.pool.shrink(self, capacity); + Ok(prev - capacity) } /// Sets the size of this reservation to `capacity` - pub fn resize(&mut self, capacity: usize) { - match capacity.cmp(&self.size) { - Ordering::Greater => self.grow(capacity - self.size), - Ordering::Less => self.shrink(self.size - capacity), + pub fn resize(&self, capacity: usize) { + let size = self.size.load(atomic::Ordering::Relaxed); + match capacity.cmp(&size) { + Ordering::Greater => self.grow(capacity - size), + Ordering::Less => self.shrink(size - capacity), _ => {} } } /// Try to set the size of this reservation to `capacity` - pub fn try_resize(&mut self, capacity: usize) -> Result<()> { - match capacity.cmp(&self.size) { - Ordering::Greater => self.try_grow(capacity - self.size)?, - Ordering::Less => self.shrink(self.size - capacity), + pub fn try_resize(&self, capacity: usize) -> Result<()> { + let size = self.size.load(atomic::Ordering::Relaxed); + match capacity.cmp(&size) { + Ordering::Greater => self.try_grow(capacity - size)?, + Ordering::Less => self.shrink(size - capacity), _ => {} }; Ok(()) } /// Increase the size of this reservation by `capacity` bytes - pub fn grow(&mut self, capacity: usize) { + pub fn grow(&self, capacity: usize) { self.registration.pool.grow(self, capacity); - self.size 
+= capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); } /// Try to increase the size of this reservation by `capacity` /// bytes, returning error if there is insufficient capacity left /// in the pool. - pub fn try_grow(&mut self, capacity: usize) -> Result<()> { + pub fn try_grow(&self, capacity: usize) -> Result<()> { self.registration.pool.try_grow(self, capacity)?; - self.size += capacity; + self.size.fetch_add(capacity, atomic::Ordering::Relaxed); Ok(()) } @@ -447,10 +460,16 @@ impl MemoryReservation { /// # Panics /// /// Panics if `capacity` exceeds [`Self::size`] - pub fn split(&mut self, capacity: usize) -> MemoryReservation { - self.size = self.size.checked_sub(capacity).unwrap(); + pub fn split(&self, capacity: usize) -> MemoryReservation { + self.size + .fetch_update( + atomic::Ordering::Relaxed, + atomic::Ordering::Relaxed, + |prev| prev.checked_sub(capacity), + ) + .unwrap(); Self { - size: capacity, + size: atomic::AtomicUsize::new(capacity), registration: Arc::clone(&self.registration), } } @@ -458,7 +477,7 @@ impl MemoryReservation { /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn new_empty(&self) -> Self { Self { - size: 0, + size: atomic::AtomicUsize::new(0), registration: Arc::clone(&self.registration), } } @@ -466,7 +485,7 @@ impl MemoryReservation { /// Splits off all the bytes from this [`MemoryReservation`] into /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn take(&mut self) -> MemoryReservation { - self.split(self.size) + self.split(self.size.load(atomic::Ordering::Relaxed)) } } @@ -492,7 +511,7 @@ mod tests { #[test] fn test_memory_pool_underflow() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut a1 = MemoryConsumer::new("a1").register(&pool); + let a1 = MemoryConsumer::new("a1").register(&pool); assert_eq!(pool.reserved(), 0); a1.grow(100); @@ -507,7 +526,7 @@ mod tests { a1.try_grow(30).unwrap(); assert_eq!(pool.reserved(), 30); - let mut a2 
= MemoryConsumer::new("a2").register(&pool); + let a2 = MemoryConsumer::new("a2").register(&pool); a2.try_grow(25).unwrap_err(); assert_eq!(pool.reserved(), 30); @@ -521,7 +540,7 @@ mod tests { #[test] fn test_split() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); assert_eq!(r1.size(), 20); @@ -542,10 +561,10 @@ mod tests { #[test] fn test_new_empty() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.new_empty(); + let r2 = r1.new_empty(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 20); @@ -559,7 +578,7 @@ mod tests { let mut r1 = MemoryConsumer::new("r1").register(&pool); r1.try_grow(20).unwrap(); - let mut r2 = r1.take(); + let r2 = r1.take(); r2.try_grow(5).unwrap(); assert_eq!(r1.size(), 0); @@ -572,4 +591,37 @@ mod tests { assert_eq!(r2.size(), 25); assert_eq!(pool.reserved(), 28); } + + #[test] + fn test_try_shrink() { + let pool = Arc::new(GreedyMemoryPool::new(100)) as _; + let r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(50).unwrap(); + assert_eq!(r1.size(), 50); + assert_eq!(pool.reserved(), 50); + + // Successful shrink returns new size and frees pool memory + let new_size = r1.try_shrink(30).unwrap(); + assert_eq!(new_size, 20); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 20); + + // Freed pool memory is now available to other consumers + let r2 = MemoryConsumer::new("r2").register(&pool); + r2.try_grow(80).unwrap(); + assert_eq!(pool.reserved(), 100); + + // Shrinking more than allocated fails without changing state + let err = r1.try_shrink(25); + assert!(err.is_err()); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 100); + + // Shrink to exactly zero + let new_size = r1.try_shrink(20).unwrap(); + 
assert_eq!(new_size, 0); + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 80); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index bf74b5f6f4c6b..b10270851cc06 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -212,7 +212,7 @@ impl MemoryPool for FairSpillPool { .checked_div(state.num_spill) .unwrap_or(spill_available); - if reservation.size + additional > available { + if reservation.size() + additional > available { return Err(insufficient_capacity_err( reservation, additional, @@ -264,7 +264,7 @@ fn insufficient_capacity_err( "Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", human_readable_size(additional), reservation.registration.consumer.name, - human_readable_size(reservation.size), + human_readable_size(reservation.size()), human_readable_size(available) ) } @@ -526,12 +526,12 @@ mod tests { fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; - let mut r1 = MemoryConsumer::new("unspillable").register(&pool); + let r1 = MemoryConsumer::new("unspillable").register(&pool); // Can grow beyond capacity of pool r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("r2") + let r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -563,7 +563,7 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("r3") + let r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); @@ -584,7 +584,7 @@ mod tests { r1.free(); assert_eq!(pool.reserved(), 80); - let mut r4 = MemoryConsumer::new("s4").register(&pool); + let r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 
30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } @@ -601,18 +601,18 @@ mod tests { // Test: use all the different interfaces to change reservation size // set r1=50, using grow and shrink - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.grow(50); r1.grow(20); r1.shrink(20); // set r2=15 using try_grow - let mut r2 = MemoryConsumer::new("r2").register(&pool); + let r2 = MemoryConsumer::new("r2").register(&pool); r2.try_grow(15) .expect("should succeed in memory allotment for r2"); // set r3=20 using try_resize - let mut r3 = MemoryConsumer::new("r3").register(&pool); + let r3 = MemoryConsumer::new("r3").register(&pool); r3.try_resize(25) .expect("should succeed in memory allotment for r3"); r3.try_resize(20) @@ -620,12 +620,12 @@ mod tests { // set r4=10 // this should not be reported in top 3 - let mut r4 = MemoryConsumer::new("r4").register(&pool); + let r4 = MemoryConsumer::new("r4").register(&pool); r4.grow(10); // Test: reports if new reservation causes error // using the previously set sizes for other consumers - let mut r5 = MemoryConsumer::new("r5").register(&pool); + let r5 = MemoryConsumer::new("r5").register(&pool); let res = r5.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -650,7 +650,7 @@ mod tests { let same_name = "foo"; // Test: see error message when no consumers recorded yet - let mut r0 = MemoryConsumer::new(same_name).register(&pool); + let r0 = MemoryConsumer::new(same_name).register(&pool); let res = r0.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -665,7 +665,7 @@ mod tests { r0.grow(10); // make r0=10, pool available=90 let new_consumer_same_name = MemoryConsumer::new(same_name); - let mut r1 = new_consumer_same_name.register(&pool); + let r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() 
message is per reservation, not per consumer. // a followup PR will clarify this message "0 bytes already allocated for this reservation" let res = r1.try_grow(150); @@ -695,7 +695,7 @@ mod tests { // will be recognized as different in the TrackConsumersPool let consumer_with_same_name_but_different_hash = MemoryConsumer::new(same_name).with_can_spill(true); - let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); + let r2 = consumer_with_same_name_but_different_hash.register(&pool); let res = r2.try_grow(150); assert!(res.is_err()); let error = res.unwrap_err().strip_backtrace(); @@ -714,10 +714,10 @@ mod tests { // Baseline: see the 2 memory consumers let setting = make_settings(); let _bound = setting.bind_to_scope(); - let mut r0 = MemoryConsumer::new("r0").register(&pool); + let r0 = MemoryConsumer::new("r0").register(&pool); r0.grow(10); let r1_consumer = MemoryConsumer::new("r1"); - let mut r1 = r1_consumer.register(&pool); + let r1 = r1_consumer.register(&pool); r1.grow(20); let res = r0.try_grow(150); @@ -791,13 +791,13 @@ mod tests { .downcast::>() .unwrap(); // set r1=20 - let mut r1 = MemoryConsumer::new("r1").register(&pool); + let r1 = MemoryConsumer::new("r1").register(&pool); r1.grow(20); // set r2=15 - let mut r2 = MemoryConsumer::new("r2").register(&pool); + let r2 = MemoryConsumer::new("r2").register(&pool); r2.grow(15); // set r3=45 - let mut r3 = MemoryConsumer::new("r3").register(&pool); + let r3 = MemoryConsumer::new("r3").register(&pool); r3.grow(45); let downcasted = upcasted diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 67398d59f1374..67604c424c766 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -19,7 +19,7 @@ //! store, memory manager, disk manager. 
#[expect(deprecated)] -use crate::disk_manager::DiskManagerConfig; +use crate::disk_manager::{DiskManagerConfig, SpillingProgress}; use crate::{ disk_manager::{DiskManager, DiskManagerBuilder, DiskManagerMode}, memory_pool::{ @@ -199,6 +199,11 @@ impl RuntimeEnv { self.object_store_registry.get_store(url.as_ref()) } + /// Returns the current spilling progress + pub fn spilling_progress(&self) -> SpillingProgress { + self.disk_manager.spilling_progress() + } + /// Register an [`EncryptionFactory`] with an associated identifier that can be later /// used to configure encryption when reading or writing Parquet. /// If an encryption factory with the same identifier was already registered, it is replaced and returned. diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 99c21d4abdb6e..1aa42470a1481 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::{ array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, compute::{CastOptions, kernels, max, min}, - datatypes::DataType, + datatypes::{DataType, Field}, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; @@ -274,7 +274,17 @@ impl ColumnarValue { Ok(args) } - /// Cast's this [ColumnarValue] to the specified `DataType` + /// Cast this [ColumnarValue] to the specified `DataType` + /// + /// # Struct Casting Behavior + /// + /// When casting struct types, fields are matched **by name** rather than position: + /// - Source fields are matched to target fields using case-sensitive name comparison + /// - Fields are reordered to match the target schema + /// - Missing target fields are filled with null arrays + /// - Extra source fields are ignored + /// + /// For non-struct types, uses Arrow's standard positional casting. 
pub fn cast_to( &self, cast_type: &DataType, @@ -283,12 +293,8 @@ impl ColumnarValue { let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); match self { ColumnarValue::Array(array) => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(ColumnarValue::Array(kernels::cast::cast_with_options( - array, - cast_type, - &cast_options, - )?)) + let casted = cast_array_by_name(array, cast_type, &cast_options)?; + Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( scalar.cast_to_with_options(cast_type, &cast_options)?, @@ -297,6 +303,37 @@ impl ColumnarValue { } } +fn cast_array_by_name( + array: &ArrayRef, + cast_type: &DataType, + cast_options: &CastOptions<'static>, +) -> Result { + // If types are already equal, no cast needed + if array.data_type() == cast_type { + return Ok(Arc::clone(array)); + } + + match cast_type { + DataType::Struct(_) => { + // Field name is unused; only the struct's inner field names matter + let target_field = Field::new("_", cast_type.clone(), true); + datafusion_common::nested_struct::cast_column( + array, + &target_field, + cast_options, + ) + } + _ => { + ensure_date_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) 
+ } + } +} + fn ensure_date_array_timestamp_bounds( array: &ArrayRef, cast_type: &DataType, @@ -378,8 +415,8 @@ impl fmt::Display for ColumnarValue { mod tests { use super::*; use arrow::{ - array::{Date64Array, Int32Array}, - datatypes::TimeUnit, + array::{Date64Array, Int32Array, StructArray}, + datatypes::{Field, Fields, TimeUnit}, }; #[test] @@ -553,6 +590,102 @@ mod tests { ); } + #[test] + fn cast_struct_by_field_name() { + let source_fields = Fields::from(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![ + Arc::new(Int32Array::from(vec![Some(3)])), + Arc::new(Int32Array::from(vec![Some(4)])), + ], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_a = struct_array + .column_by_name("a") + .expect("expected field a in cast result"); + let field_b = struct_array + .column_by_name("b") + .expect("expected field b in cast result"); + + assert_eq!( + field_a + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 4 + ); + assert_eq!( + field_b + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 3 + ); + } + + #[test] + fn cast_struct_missing_field_inserts_nulls() { + let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + 
source_fields, + vec![Arc::new(Int32Array::from(vec![Some(5)]))], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_b = struct_array + .column_by_name("b") + .expect("expected missing field to be added"); + + assert!(field_b.is_null(0)); + } + #[test] fn cast_date64_array_to_timestamp_overflow() { let overflow_value = i64::MAX / 1_000_000 + 1; diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 860e69245a7fd..08c9f01f13c40 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -89,6 +89,9 @@ impl EmitTo { /// optional and is harder to implement than `Accumulator`, but can be much /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. +/// For more background, please also see the [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog] +/// +/// [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog]: https://datafusion.apache.org/blog/2023/08/05/datafusion_fast_grouping /// /// [`NullState`] can help keep the state for groups that have not seen any /// values and produce the correct output for those groups. 
diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 2be066beaad24..c9a95fd294503 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -32,7 +32,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod accumulator; pub mod casts; @@ -41,7 +40,10 @@ pub mod dyn_eq; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod operator; +pub mod placement; pub mod signature; pub mod sort_properties; pub mod statistics; pub mod type_coercion; + +pub use placement::ExpressionPlacement; diff --git a/datafusion/expr-common/src/placement.rs b/datafusion/expr-common/src/placement.rs new file mode 100644 index 0000000000000..8212ba618e322 --- /dev/null +++ b/datafusion/expr-common/src/placement.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression placement information for optimization decisions. + +/// Describes where an expression should be placed in the query plan for +/// optimal execution. 
This is used by optimizers to make decisions about +/// expression placement, such as whether to push expressions down through +/// projections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionPlacement { + /// A constant literal value. + Literal, + /// A simple column reference. + Column, + /// A cheap expression that can be pushed to leaf nodes in the plan. + /// Examples include `get_field` for struct field access. + /// Pushing these expressions down in the plan can reduce data early + /// at low compute cost. + /// See [`ExpressionPlacement::should_push_to_leaves`] for details. + MoveTowardsLeafNodes, + /// An expensive expression that should stay where it is in the plan. + /// Examples include complex scalar functions or UDFs. + KeepInPlace, +} + +impl ExpressionPlacement { + /// Returns true if the expression can be pushed down to leaf nodes + /// in the query plan. + /// + /// This returns true for: + /// - [`ExpressionPlacement::Column`]: Simple column references can be pushed down. They do no compute and do not increase or + /// decrease the amount of data being processed. + /// A projection that reduces the number of columns can eliminate unnecessary data early, + /// but this method only considers one expression at a time, not a projection as a whole. + /// - [`ExpressionPlacement::MoveTowardsLeafNodes`]: Cheap expressions can be pushed down to leaves to take advantage of + /// early computation and potential optimizations at the data source level. + /// For example `struct_col['field']` is cheap to compute (just an Arc clone of the nested array for `'field'`) + /// and thus can reduce data early in the plan at very low compute cost. + /// It may even be possible to eliminate the expression entirely if the data source can project only the needed field + /// (as e.g. Parquet can). 
+ pub fn should_push_to_leaves(&self) -> bool { + matches!( + self, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ) + } +} diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 54bb84f03d3d5..4c766b2cc50c9 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -1585,6 +1585,7 @@ mod tests { vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt32, DataType::UInt32], vec![DataType::UInt64, DataType::UInt64], + vec![DataType::Float16, DataType::Float16], vec![DataType::Float32, DataType::Float32], vec![DataType::Float64, DataType::Float64] ] diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 01d093950d471..ab4d086e4ca5f 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -42,6 +42,7 @@ pub static NUMERICS: &[DataType] = &[ DataType::UInt16, DataType::UInt32, DataType::UInt64, + DataType::Float16, DataType::Float32, DataType::Float64, ]; diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c9b39eacefc6a..4daa8a7a7f87d 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -17,6 +17,7 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -350,16 +351,6 @@ impl<'a> BinaryTypeCoercer<'a> { // TODO Move the rest inside of BinaryTypeCoercer -fn is_decimal(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Decimal32(..) - | DataType::Decimal64(..) - | DataType::Decimal128(..) - | DataType::Decimal256(..) 
- ) -} - /// Returns true if both operands are Date types (Date32 or Date64) /// Used to detect Date - Date operations which should return Int64 (days difference) fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool { @@ -401,8 +392,8 @@ fn math_decimal_coercion( } // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale (lhs, rhs) - if is_decimal(lhs) - && is_decimal(rhs) + if lhs.is_decimal() + && rhs.is_decimal() && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => { let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; @@ -479,7 +470,9 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option Option { @@ -1026,8 +1019,8 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -1236,30 +1229,123 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { (Struct(lhs_fields), Struct(rhs_fields)) => { + // Field count must match for coercion if lhs_fields.len() != rhs_fields.len() { return None; } - let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) - .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) - .collect::>>()?; - - // preserve the field name and nullability - let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()); + // If the two structs have exactly the same set of field names (possibly in + // different order), prefer name-based coercion. Otherwise fall back to + // positional coercion which preserves backward compatibility. + // + // Name-based coercion is used in: + // 1. Array construction: [s1, s2] where s1 and s2 have reordered fields + // 2. UNION operations: different field orders unified by name + // 3. VALUES clauses: heterogeneous struct rows unified by field name + // 4. JOIN conditions: structs with matching field names + // 5. 
Window functions: partitions/orders by struct fields + // 6. Aggregate functions: collecting structs with reordered fields + // + // See docs/source/user-guide/sql/struct_coercion.md for detailed examples. + if fields_have_same_names(lhs_fields, rhs_fields) { + return coerce_struct_by_name(lhs_fields, rhs_fields); + } - let fields: Vec = coerced_types - .into_iter() - .zip(orig_fields) - .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) - .collect(); - Some(Struct(fields.into())) + coerce_struct_by_position(lhs_fields, rhs_fields) } _ => None, } } +/// Return true if every left-field name exists in the right fields (and lengths are equal). +/// +/// # Assumptions +/// **This function assumes field names within each struct are unique.** This assumption is safe +/// because field name uniqueness is enforced at multiple levels: +/// - **Arrow level:** `StructType` construction enforces unique field names at the schema level +/// - **DataFusion level:** SQL parser rejects duplicate field names in `CREATE TABLE` and struct type definitions +/// - **Runtime level:** `StructArray::try_new()` validates field uniqueness +/// +/// Therefore, we don't need to handle degenerate cases like: +/// - `struct -> struct` (target has duplicate field names) +/// - `struct -> struct` (source has duplicate field names) +fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool { + // Debug assertions: field names should be unique within each struct + #[cfg(debug_assertions)] + { + let lhs_names: HashSet<_> = lhs_fields.iter().map(|f| f.name()).collect(); + assert_eq!( + lhs_names.len(), + lhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + + let rhs_names_check: HashSet<_> = rhs_fields.iter().map(|f| f.name()).collect(); + assert_eq!( + rhs_names_check.len(), + rhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + } + + let rhs_names: 
HashSet<&str> = rhs_fields.iter().map(|f| f.name().as_str()).collect(); + lhs_fields + .iter() + .all(|lf| rhs_names.contains(lf.name().as_str())) +} + +/// Coerce two structs by matching fields by name. Assumes the name-sets match. +fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option { + use arrow::datatypes::DataType::*; + + let rhs_by_name: HashMap<&str, &FieldRef> = + rhs_fields.iter().map(|f| (f.name().as_str(), f)).collect(); + + let mut coerced: Vec = Vec::with_capacity(lhs_fields.len()); + + for lhs in lhs_fields.iter() { + let rhs = rhs_by_name.get(lhs.name().as_str()).unwrap(); // safe: caller ensured names match + let coerced_type = comparison_coercion(lhs.data_type(), rhs.data_type())?; + let is_nullable = lhs.is_nullable() || rhs.is_nullable(); + coerced.push(Arc::new(Field::new( + lhs.name().clone(), + coerced_type, + is_nullable, + ))); + } + + Some(Struct(coerced.into())) +} + +/// Coerce two structs positionally (left-to-right). This preserves field names from +/// the left struct and uses the combined nullability. +fn coerce_struct_by_position( + lhs_fields: &Fields, + rhs_fields: &Fields, +) -> Option { + use arrow::datatypes::DataType::*; + + // First coerce individual types; fail early if any pair cannot be coerced. + let coerced_types: Vec = lhs_fields + .iter() + .zip(rhs_fields.iter()) + .map(|(l, r)| comparison_coercion(l.data_type(), r.data_type())) + .collect::>>()?; + + // Build final fields preserving left-side names and combined nullability. 
+ let orig_pairs = lhs_fields.iter().zip(rhs_fields.iter()); + let fields: Vec = coerced_types + .into_iter() + .zip(orig_pairs) + .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) + .collect(); + + Some(Struct(fields.into())) +} + /// returns the result of coercing two fields to a common type fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef { let is_nullable = lhs.is_nullable() || rhs.is_nullable(); @@ -1709,9 +1795,10 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// Coercion rules for like operations. /// This is a union of string coercion rules, dictionary coercion rules, and REE coercion rules +/// Note: list_coercion is intentionally NOT included here because LIKE is a string pattern +/// matching operation and is not supported for nested types (List, Struct, etc.) pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, false)) diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index bb9d44953b9f9..eb5622fedb8aa 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -228,6 +228,53 @@ fn test_type_coercion_arithmetic() -> Result<()> { Ok(()) } +#[test] +fn test_bitwise_coercion_non_integer_types() -> Result<()> { + let err = BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float32, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float32" + ); + + let err = 
BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float64, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float64" + ); + + let err = BinaryTypeCoercer::new( + &DataType::Decimal128(10, 2), + &Operator::BitwiseAnd, + &DataType::Decimal128(10, 2), + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Decimal128(10, 2) & Decimal128(10, 2)" + ); + + let dict_int8 = DataType::Dictionary(DataType::Int8.into(), DataType::Int8.into()); + test_coercion_binary_rule!(dict_int8, dict_int8, Operator::BitwiseAnd, dict_int8); + + Ok(()) +} + fn test_math_decimal_coercion_rule( lhs_type: DataType, rhs_type: DataType, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c7d825ce1d52f..87e8e029a6ee5 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -38,11 +38,12 @@ use datafusion_common::tree_node::{ use datafusion_common::{ Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference, }; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_functions_window_common::field::WindowUDFFieldArgs; #[cfg(feature = "sql")] use sqlparser::ast::{ ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, RenameSelectItem, - ReplaceSelectElement, display_comma_separated, + ReplaceSelectElement, }; // Moved in 51.0.0 to datafusion_common @@ -309,6 +310,7 @@ impl From for NullTreatment { /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +/// ``` #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. @@ -372,6 +374,8 @@ pub enum Expr { Exists(Exists), /// IN subquery InSubquery(InSubquery), + /// Set comparison subquery (e.g. 
`= ANY`, `> ALL`) + SetComparison(SetComparison), /// Scalar subquery ScalarSubquery(Subquery), /// Represents a reference to all available fields in a specific schema, @@ -953,7 +957,7 @@ impl AggregateFunction { pub enum WindowFunctionDefinition { /// A user defined aggregate function AggregateUDF(Arc), - /// A user defined aggregate function + /// A user defined window function WindowUDF(Arc), } @@ -1101,6 +1105,54 @@ impl Exists { } } +/// Whether the set comparison uses `ANY`/`SOME` or `ALL` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum SetQuantifier { + /// `ANY` (or `SOME`) + Any, + /// `ALL` + All, +} + +impl Display for SetQuantifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SetQuantifier::Any => write!(f, "ANY"), + SetQuantifier::All => write!(f, "ALL"), + } + } +} + +/// Set comparison subquery (e.g. `= ANY`, `> ALL`) +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct SetComparison { + /// The expression to compare + pub expr: Box, + /// Subquery that will produce a single column of data to compare against + pub subquery: Subquery, + /// Comparison operator (e.g. `=`, `>`, `<`) + pub op: Operator, + /// Quantifier (`ANY`/`ALL`) + pub quantifier: SetQuantifier, +} + +impl SetComparison { + /// Create a new set comparison expression + pub fn new( + expr: Box, + subquery: Subquery, + op: Operator, + quantifier: SetQuantifier, + ) -> Self { + Self { + expr, + subquery, + op, + quantifier, + } + } +} + /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1268,7 +1320,6 @@ impl Display for ExceptSelectItem { } } -#[cfg(not(feature = "sql"))] pub fn display_comma_separated(slice: &[T]) -> String where T: Display, @@ -1487,6 +1538,24 @@ impl Expr { } } + /// Returns placement information for this expression. 
+ /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + pub fn placement(&self) -> ExpressionPlacement { + match self { + Expr::Column(_) => ExpressionPlacement::Column, + Expr::Literal(_, _) => ExpressionPlacement::Literal, + Expr::Alias(inner) => inner.expr.placement(), + Expr::ScalarFunction(func) => { + let arg_placements: Vec<_> = + func.args.iter().map(|arg| arg.placement()).collect(); + func.func.placement(&arg_placements) + } + _ => ExpressionPlacement::KeepInPlace, + } + } + /// Return String representation of the variant represented by `self` /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { @@ -1503,6 +1572,7 @@ impl Expr { Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", + Expr::SetComparison(..) => "SetComparison", Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", @@ -2058,6 +2128,7 @@ impl Expr { | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) + | Expr::SetComparison(..) | Expr::IsFalse(..) | Expr::IsNotFalse(..) | Expr::IsNotNull(..) @@ -2651,6 +2722,16 @@ impl HashNode for Expr { subquery.hash(state); negated.hash(state); } + Expr::SetComparison(SetComparison { + expr: _, + subquery, + op, + quantifier, + }) => { + subquery.hash(state); + op.hash(state); + quantifier.hash(state); + } Expr::ScalarSubquery(subquery) => { subquery.hash(state); } @@ -2841,6 +2922,12 @@ impl Display for SchemaDisplay<'_> { write!(f, "NOT IN") } Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::SetComparison(SetComparison { + expr, + op, + quantifier, + .. 
+ }) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())), Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), Expr::IsNotTrue(expr) => { @@ -3316,6 +3403,12 @@ impl Display for Expr { subquery, negated: false, }) => write!(f, "{expr} IN ({subquery:?})"), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::ScalarFunction(fun) => { @@ -3799,6 +3892,7 @@ mod test { } use super::*; + use crate::logical_plan::{EmptyRelation, LogicalPlan}; #[test] fn test_display_wildcard() { @@ -3889,6 +3983,28 @@ mod test { ) } + #[test] + fn test_display_set_comparison() { + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let expr = Expr::SetComparison(SetComparison::new( + Box::new(Expr::Column(Column::from_name("a"))), + subquery, + Operator::Gt, + SetQuantifier::Any, + )); + + assert_eq!(format!("{expr}"), "a > ANY ()"); + assert_eq!(format!("{}", expr.human_display()), "a > ANY ()"); + } + #[test] fn test_schema_display_alias_with_relation() { assert_eq!( diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index a0faca76e91e4..32a88ab8cf310 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -261,9 +261,16 @@ fn coerce_exprs_for_schema( #[expect(deprecated)] Expr::Wildcard { .. 
} => Ok(expr), _ => { - // maintain the original name when casting - let name = dst_schema.field(idx).name(); - Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + match expr { + // maintain the original name when casting a column, to avoid the + // tablename being added to it when not explicitly set by the query + // (see: https://github.com/apache/datafusion/issues/18818) + Expr::Column(ref column) => { + let name = column.name().to_owned(); + Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + } + _ => Ok(expr.cast_to(new_type, src_schema)?), + } } } } else { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 854e907d68b1a..f4e4f014f533c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,7 @@ use crate::expr::{ InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::type_coercion::functions::fields_with_udf; +use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use crate::udf::ReturnFieldArgs; use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; @@ -152,49 +152,16 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(_func) => { - let return_type = self.to_field(schema)?.1.data_type().clone(); - Ok(return_type) - } - Expr::WindowFunction(window_function) => Ok(self - .window_function_field(schema, window_function)? - .data_type() - .clone()), - Expr::AggregateFunction(AggregateFunction { - func, - params: AggregateFunctionParams { args, .. 
}, - }) => { - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - let new_fields = fields_with_udf(&fields, func.as_ref()) - .map_err(|err| { - let data_types = fields - .iter() - .map(|f| f.data_type().clone()) - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - Ok(func.return_field(&new_fields)?.data_type().clone()) + Expr::ScalarFunction(_) + | Expr::WindowFunction(_) + | Expr::AggregateFunction(_) => { + Ok(self.to_field(schema)?.1.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Between { .. } | Expr::InList { .. } | Expr::IsNotNull(_) @@ -349,18 +316,9 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(_func) => { - let field = self.to_field(input_schema)?.1; - - let nullable = field.is_nullable(); - Ok(nullable) - } - Expr::AggregateFunction(AggregateFunction { func, .. }) => { - Ok(func.is_nullable()) - } - Expr::WindowFunction(window_function) => Ok(self - .window_function_field(input_schema, window_function)? - .is_nullable()), + Expr::ScalarFunction(_) + | Expr::AggregateFunction(_) + | Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()), Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) @@ -372,6 +330,7 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => Ok(false), + Expr::SetComparison(_) => Ok(true), Expr::InSubquery(InSubquery { expr, .. 
}) => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) @@ -532,69 +491,49 @@ impl ExprSchemable for Expr { ))) } Expr::WindowFunction(window_function) => { - self.window_function_field(schema, window_function) - } - Expr::AggregateFunction(aggregate_function) => { - let AggregateFunction { - func, - params: AggregateFunctionParams { args, .. }, + let WindowFunction { + fun, + params: WindowFunctionParams { args, .. }, .. - } = aggregate_function; + } = window_function.as_ref(); let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = fields_with_udf(&fields, func.as_ref()) - .map_err(|err| { - let arg_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })? - .into_iter() - .collect::>(); - + match fun { + WindowFunctionDefinition::AggregateUDF(udaf) => { + let new_fields = + verify_function_arguments(udaf.as_ref(), &fields)?; + let return_field = udaf.return_field(&new_fields)?; + Ok(return_field) + } + WindowFunctionDefinition::WindowUDF(udwf) => { + let new_fields = + verify_function_arguments(udwf.as_ref(), &fields)?; + let return_field = udwf + .field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?; + Ok(return_field) + } + } + } + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. 
}, + }) => { + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = - fields_with_udf(&fields, func.as_ref()).map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })?; + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; let arguments = args .iter() @@ -626,6 +565,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -659,7 +599,16 @@ impl ExprSchemable for Expr { // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? 
- if can_cast_types(&this_type, cast_to_type) { + // Special handling for struct-to-struct casts with name-based field matching + let can_cast = match (&this_type, cast_to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // Always allow struct-to-struct casts; field matching happens at runtime + true + } + _ => can_cast_types(&this_type, cast_to_type), + }; + + if can_cast { match self { Expr::ScalarSubquery(subquery) => { Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) @@ -672,6 +621,33 @@ impl ExprSchemable for Expr { } } +/// Verify that function is invoked with correct number and type of arguments as +/// defined in `TypeSignature`. +fn verify_function_arguments( + function: &F, + input_fields: &[FieldRef], +) -> Result> { + fields_with_udf(input_fields, function).map_err(|err| { + let data_types = input_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_message( + function.name(), + function.signature(), + &data_types + ) + ) + }) +} + /// Returns the innermost [Expr] that is provably null if `expr` is null. fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { match expr { @@ -682,90 +658,6 @@ fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { } } -impl Expr { - /// Common method for window functions that applies type coercion - /// to all arguments of the window function to check if it matches - /// its signature. - /// - /// If successful, this method returns the data type and - /// nullability of the window function's result. - /// - /// Otherwise, returns an error if there's a type mismatch between - /// the window function's signature and the provided arguments. - fn window_function_field( - &self, - schema: &dyn ExprSchema, - window_function: &WindowFunction, - ) -> Result { - let WindowFunction { - fun, - params: WindowFunctionParams { args, .. }, - .. 
- } = window_function; - - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - match fun { - WindowFunctionDefinition::AggregateUDF(udaf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_udf(&fields, udaf.as_ref()) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - - udaf.return_field(&new_fields) - } - WindowFunctionDefinition::WindowUDF(udwf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_udf(&fields, udwf.as_ref()) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); - - udwf.field(field_args) - } - } - } -} - /// Cast subquery in InSubquery/ScalarSubquery to a given type. /// /// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 4fb78933d7a5c..cb136229bf88d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! [DataFusion](https://github.com/apache/datafusion) //! 
is an extensible query execution framework that uses @@ -77,6 +76,7 @@ pub mod statistics { pub use datafusion_expr_common::statistics::*; } mod predicate_bounds; +pub mod preimage; pub mod ptr_eq; pub mod test; pub mod tree_node; @@ -95,6 +95,7 @@ pub use datafusion_expr_common::accumulator::Accumulator; pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; +pub use datafusion_expr_common::placement::ExpressionPlacement; pub use datafusion_expr_common::signature::{ ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TIMEZONE_WILDCARD, TypeSignature, TypeSignatureClass, Volatility, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6f654428e41a1..2e23fef1da768 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1011,6 +1011,25 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, null_equality: NullEquality, + ) -> Result { + self.join_detailed_with_options( + right, + join_type, + join_keys, + filter, + null_equality, + false, + ) + } + + pub fn join_detailed_with_options( + self, + right: LogicalPlan, + join_type: JoinType, + join_keys: (Vec>, Vec>), + filter: Option, + null_equality: NullEquality, + null_aware: bool, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1128,6 +1147,7 @@ impl LogicalPlanBuilder { join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), null_equality, + null_aware, }))) } @@ -1201,6 +1221,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1217,6 +1238,7 @@ impl LogicalPlanBuilder { JoinType::Inner, JoinConstraint::On, 
NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1471,6 +1493,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -2756,12 +2779,12 @@ mod tests { assert_snapshot!(plan, @r" Union - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right Values: (Int32(1)) - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 480974b055d11..58c7feb616179 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -117,13 +117,7 @@ pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ { write!(f, ", ")?; } let nullable_str = if field.is_nullable() { ";N" } else { "" }; - write!( - f, - "{}:{:?}{}", - field.name(), - field.data_type(), - nullable_str - )?; + write!(f, "{}:{}{}", field.name(), field.data_type(), nullable_str)?; } write!(f, "]") } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 6ac3b309aa0c7..b668cbfe2cc35 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -237,6 +237,8 @@ pub enum WriteOp { Update, /// `CREATE TABLE AS SELECT` operation Ctas, + /// `TRUNCATE` operation + Truncate, } impl WriteOp { @@ -247,6 +249,7 @@ impl WriteOp { WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", + WriteOp::Truncate => "Truncate", } } } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 762491a255cbc..0889afd08fee4 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -22,7 +22,7 @@ use datafusion_common::{ use crate::{ Aggregate, 
Expr, Filter, Join, JoinType, LogicalPlan, Window, - expr::{Exists, InSubquery}, + expr::{Exists, InSubquery, SetComparison}, expr_rewriter::strip_outer_reference, utils::{collect_subquery_cols, split_conjunction}, }; @@ -81,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { assert_valid_extension_nodes(&subquery.subquery, check)?; } @@ -133,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } @@ -206,14 +208,16 @@ pub fn check_subquery_expr( if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" + "Correlated scalar subquery can only be used in Projection, \ + Filter, Aggregate plan nodes" ), }?; } @@ -229,6 +233,20 @@ pub fn check_subquery_expr( ); } } + if let Expr::SetComparison(set_comparison) = expr + && set_comparison.subquery.subquery.schema().fields().len() > 1 + { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ 
-237,7 +255,7 @@ pub fn check_subquery_expr( | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( - "In/Exist subquery can only be used in \ + "In/Exist/SetComparison subquery can only be used in \ Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ but was used in [{}]", outer_plan.display() diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4219c24bfc9c9..032a97bdb3efa 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -661,6 +661,7 @@ impl LogicalPlan { on, schema: _, null_equality, + null_aware, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -682,6 +683,7 @@ impl LogicalPlan { filter, schema: DFSchemaRef::new(schema), null_equality, + null_aware, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -901,6 +903,7 @@ impl LogicalPlan { join_constraint, on, null_equality, + null_aware, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -942,6 +945,7 @@ impl LogicalPlan { filter: filter_expr, schema: DFSchemaRef::new(schema), null_equality: *null_equality, + null_aware: *null_aware, })) } LogicalPlan::Subquery(Subquery { @@ -1969,13 +1973,16 @@ impl LogicalPlan { }; match join_constraint { JoinConstraint::On => { - write!( - f, - "{} Join: {}{}", - join_type, - join_expr.join(", "), - filter_expr - ) + write!(f, "{join_type} Join:",)?; + if !join_expr.is_empty() || !filter_expr.is_empty() { + write!( + f, + " {}{}", + join_expr.join(", "), + filter_expr + )?; + } + Ok(()) } JoinConstraint::Using => { write!( @@ -3781,6 +3788,14 @@ pub struct Join { pub schema: DFSchemaRef, /// Defines the null equality for the join. pub null_equality: NullEquality, + /// Whether this is a null-aware anti join (for NOT IN semantics). + /// + /// Only applies to LeftAnti joins. 
When true, implements SQL NOT IN semantics where: + /// - If the right side (subquery) contains any NULL in join keys, no rows are output + /// - Left side rows with NULL in join keys are not output + /// + /// This is required for correct NOT IN subquery behavior with three-valued logic. + pub null_aware: bool, } impl Join { @@ -3798,10 +3813,12 @@ impl Join { /// * `join_type` - Type of join (Inner, Left, Right, etc.) /// * `join_constraint` - Join constraint (On, Using) /// * `null_equality` - How to handle nulls in join comparisons + /// * `null_aware` - Whether this is a null-aware anti join (for NOT IN semantics) /// /// # Returns /// /// A new Join operator with the computed schema + #[expect(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -3810,6 +3827,7 @@ impl Join { join_type: JoinType, join_constraint: JoinConstraint, null_equality: NullEquality, + null_aware: bool, ) -> Result { let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -3822,6 +3840,7 @@ impl Join { join_constraint, schema: Arc::new(join_schema), null_equality, + null_aware, }) } @@ -3877,6 +3896,7 @@ impl Join { join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), null_equality: original_join.null_equality, + null_aware: original_join.null_aware, }, requalified, )) @@ -5329,6 +5349,7 @@ mod tests { join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), null_equality: NullEquality::NullEqualsNothing, + null_aware: false, })) } @@ -5440,6 +5461,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; match join_type { @@ -5585,6 +5607,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5636,6 +5659,7 @@ mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5685,6 +5709,7 @@ 
mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNull, + false, )?; assert_eq!(join.null_equality, NullEquality::NullEqualsNull); @@ -5727,6 +5752,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5766,6 +5792,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; assert_eq!( diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 62a27b0a025ad..a1285510da569 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ }; use datafusion_common::tree_node::TreeNodeRefContainer; -use crate::expr::{Exists, InSubquery}; +use crate::expr::{Exists, InSubquery, SetComparison}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -133,6 +133,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -143,6 +144,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -564,6 +566,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -574,6 +577,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr @@ -815,6 +819,7 @@ impl LogicalPlan { expr.apply(|expr| match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. 
}) | Expr::ScalarSubquery(subquery) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is @@ -856,6 +861,22 @@ impl LogicalPlan { })), _ => internal_err!("Transformation should return Subquery"), }), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + })) + } + _ => internal_err!("Transformation should return Subquery"), + }), Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? .map_data(|s| match s { LogicalPlan::Subquery(subquery) => { diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 3e7ba5d4f575a..0671f31f6d154 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -86,6 +86,10 @@ use crate::window_state::WindowAggState; /// [`uses_window_frame`]: Self::uses_window_frame /// [`include_rank`]: Self::include_rank /// [`supports_bounded_execution`]: Self::supports_bounded_execution +/// +/// For more background, please also see the [User defined Window Functions in DataFusion blog] +/// +/// [User defined Window Functions in DataFusion blog]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions pub trait PartitionEvaluator: Debug + Send { /// When the window frame has a fixed beginning (e.g UNBOUNDED /// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 954f511651ced..837a9eefe289f 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -139,6 +139,10 @@ pub trait ContextProvider { } /// Customize planning of SQL AST expressions to [`Expr`]s +/// +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to 
TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible @@ -249,13 +253,6 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, such as `expr = ANY(array_expr)` - /// - /// Returns origin binary expression if not possible - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - Ok(PlannerResult::Original(expr)) - } - /// Plans aggregate functions, such as `COUNT()` /// /// Returns original expression arguments if not possible @@ -369,13 +366,16 @@ impl PlannedRelation { #[derive(Debug)] pub enum RelationPlanning { /// The relation was successfully planned by an extension planner - Planned(PlannedRelation), + Planned(Box), /// No extension planner handled the relation, return it for default processing - Original(TableFactor), + Original(Box), } /// Customize planning SQL table factors to [`LogicalPlan`]s. #[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait RelationPlanner: Debug + Send + Sync { /// Plan a table factor into a [`LogicalPlan`]. /// @@ -427,6 +427,9 @@ pub trait RelationPlannerContext { /// Customize planning SQL types to DataFusion (Arrow) types. 
#[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait TypePlanner: Debug + Send + Sync { /// Plan SQL [`sqlparser::ast::DataType`] to DataFusion [`DataType`] /// diff --git a/datafusion/expr/src/preimage.rs b/datafusion/expr/src/preimage.rs new file mode 100644 index 0000000000000..67ca7a91bbf38 --- /dev/null +++ b/datafusion/expr/src/preimage.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr_common::interval_arithmetic::Interval; + +use crate::Expr; + +/// Return from [`crate::ScalarUDFImpl::preimage`] +pub enum PreimageResult { + /// No preimage exists for the specified value + None, + /// The expression always evaluates to the specified constant + /// given that `expr` is within the interval + Range { expr: Expr, interval: Box }, +} diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 742bae5b2320b..226c512a974d8 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::Expr; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison, + TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use datafusion_common::Result; @@ -58,7 +58,8 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), + | Expr::InSubquery(InSubquery { expr, .. }) + | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), Expr::ScalarFunction(ScalarFunction { args, .. }) => { @@ -128,6 +129,19 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) => Transformed::no(self), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) + }), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? 
.update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index e1f2a19672825..90c137de24cb5 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -94,58 +94,6 @@ impl UDFCoercionExt for WindowUDF { } } -/// Performs type coercion for scalar function arguments. -/// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. -#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn data_types_with_scalar_udf( - current_types: &[DataType], - func: &ScalarUDF, -) -> Result> { - let current_fields = current_types - .iter() - .map(|dt| Arc::new(Field::new("f", dt.clone(), true))) - .collect::>(); - Ok(fields_with_udf(¤t_fields, func)? - .iter() - .map(|f| f.data_type().clone()) - .collect()) -} - -/// Performs type coercion for aggregate function arguments. -/// -/// Returns the fields to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. -#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn fields_with_aggregate_udf( - current_fields: &[FieldRef], - func: &AggregateUDF, -) -> Result> { - fields_with_udf(current_fields, func) -} - -/// Performs type coercion for window function arguments. -/// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. 
-#[deprecated(since = "52.0.0", note = "use fields_with_udf")] -pub fn fields_with_window_udf( - current_fields: &[FieldRef], - func: &WindowUDF, -) -> Result> { - fields_with_udf(current_fields, func) -} - /// Performs type coercion for UDF arguments. /// /// Returns the data types to which each argument must be coerced to @@ -200,6 +148,58 @@ pub fn fields_with_udf( .collect()) } +/// Performs type coercion for scalar function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn data_types_with_scalar_udf( + current_types: &[DataType], + func: &ScalarUDF, +) -> Result> { + let current_fields = current_types + .iter() + .map(|dt| Arc::new(Field::new("f", dt.clone(), true))) + .collect::>(); + Ok(fields_with_udf(¤t_fields, func)? + .iter() + .map(|f| f.data_type().clone()) + .collect()) +} + +/// Performs type coercion for aggregate function arguments. +/// +/// Returns the fields to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], + func: &AggregateUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + +/// Performs type coercion for window function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. 
+#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_window_udf( + current_fields: &[FieldRef], + func: &WindowUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + /// Performs type coercion for function arguments. /// /// Returns the data types to which each argument must be coerced to @@ -487,7 +487,7 @@ fn get_valid_types( let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() - .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); current_types.len()]) .collect(), TypeSignature::String(number) => { function_length_check(function_name, current_types.len(), *number)?; @@ -635,8 +635,13 @@ fn get_valid_types( default_casted_type.default_cast_for(current_type)?; new_types.push(casted_type); } else { - return internal_err!( - "Expect {} but received NativeType::{}, DataType: {}", + let hint = if matches!(current_native_type, NativeType::Binary) { + "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String." + } else { + "" + }; + return plan_err!( + "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}", param.desired_type(), current_native_type, current_type @@ -655,7 +660,7 @@ fn get_valid_types( valid_types .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); *number]) .collect() } TypeSignature::UserDefined => { @@ -722,7 +727,7 @@ fn get_valid_types( current_types.len() ); } - vec![(0..*number).map(|i| current_types[i].clone()).collect()] + vec![current_types.to_vec()] } TypeSignature::OneOf(types) => types .iter() @@ -800,6 +805,7 @@ fn maybe_data_types_without_coercion( /// (losslessly converted) into a value of `type_to` /// /// See the module level documentation for more detail on coercion. 
+#[deprecated(since = "53.0.0", note = "Unused internal function")] pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { if type_into == type_from { return true; @@ -846,10 +852,13 @@ fn coerced_from<'a>( (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => { + Some(type_into.clone()) + } ( Float32, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - | Float32, + | Float16 | Float32, ) => Some(type_into.clone()), ( Float64, @@ -862,6 +871,7 @@ fn coerced_from<'a>( | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal32(_, _) @@ -928,18 +938,21 @@ mod tests { use super::*; use arrow::datatypes::Field; - use datafusion_common::{assert_contains, types::logical_binary}; + use datafusion_common::{ + assert_contains, + types::{logical_binary, logical_int64}, + }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; #[test] fn test_string_conversion() { let cases = vec![ - (DataType::Utf8View, DataType::Utf8, true), - (DataType::Utf8View, DataType::LargeUtf8, true), + (DataType::Utf8View, DataType::Utf8), + (DataType::Utf8View, DataType::LargeUtf8), ]; for case in cases { - assert_eq!(can_coerce_from(&case.0, &case.1), case.2); + assert_eq!(coerced_from(&case.0, &case.1), Some(case.0)); } } @@ -1063,7 +1076,7 @@ mod tests { .unwrap_err(); assert_contains!( got.to_string(), - "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)" + "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(s)" ); Ok(()) @@ -1118,22 +1131,22 @@ mod tests { Ok(()) } - #[test] - fn test_fixed_list_wildcard_coerce() -> Result<()> { - struct MockUdf(Signature); + struct MockUdf(Signature); - impl UDFCoercionExt for MockUdf { - fn name(&self) -> 
&str { - "test" - } - fn signature(&self) -> &Signature { - &self.0 - } - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - unimplemented!() - } + impl UDFCoercionExt for MockUdf { + fn name(&self) -> &str { + "test" + } + fn signature(&self) -> &Signature { + &self.0 } + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + unimplemented!() + } + } + #[test] + fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new_list_field(DataType::Int32, false)); // able to coerce for any size let current_fields = vec![Arc::new(Field::new( @@ -1340,6 +1353,140 @@ mod tests { Ok(()) } + #[test] + fn test_coercible_nulls() -> Result<()> { + fn null_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new("field", DataType::Null, true).into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Null to Int64 if we use TypeSignatureClass::Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Null gets passed through if we use TypeSignatureClass apart from Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![DataType::Null], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Null], output); + + Ok(()) + } + + #[test] + fn test_coercible_dictionary() -> Result<()> { + let dictionary = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64)); + fn dictionary_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + 
DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Dictionary to Int64 if we use TypeSignatureClass::Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Dictionary gets passed through if we use TypeSignatureClass apart from Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![dictionary.clone()], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![dictionary.clone()], output); + + Ok(()) + } + + #[test] + fn test_coercible_run_end_encoded() -> Result<()> { + let run_end_encoded = DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Int64, true).into(), + ); + fn run_end_encoded_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Int64, true).into(), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts REE to Int64 if we use TypeSignatureClass::Native + let output = run_end_encoded_input(Coercion::new_exact( + TypeSignatureClass::Native(logical_int64()), + ))?; + assert_eq!(vec![DataType::Int64], output); + + let output = 
run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // REE gets passed through if we use TypeSignatureClass apart from Native + let output = + run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + let output = run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + Ok(()) + } + #[test] fn test_get_valid_types_coercible_binary() -> Result<()> { let signature = Signature::coercible( diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index bd1acd3f3a2e2..c92d434e34abe 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -58,11 +58,6 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { ) } -/// Determine whether the given data type `dt` is `Null`. -pub fn is_null(dt: &DataType) -> bool { - *dt == DataType::Null -} - /// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) @@ -80,22 +75,3 @@ pub fn is_datetime(dt: &DataType) -> bool { DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) ) } - -/// Determine whether the given data type `dt` is a `Utf8` or `Utf8View` or `LargeUtf8`. -pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool { - matches!( - dt, - DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 - ) -} - -/// Determine whether the given data type `dt` is a `Decimal`. 
-pub fn is_decimal(dt: &DataType) -> bool { - matches!( - dt, - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ) -} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 0654370ac7ebf..405fb256803b6 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,6 +19,7 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::preimage::PreimageResult; use crate::simplify::{ExprSimplifyResult, SimplifyContext}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; @@ -30,6 +31,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; @@ -232,6 +234,18 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Return a preimage + /// + /// See [`ScalarUDFImpl::preimage`] for more details. + pub fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_with_args`] for details. @@ -348,6 +362,13 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } + + /// Returns placement information for this function. + /// + /// See [`ScalarUDFImpl::placement`] for more details. 
+ pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) + } } impl From for ScalarUDF @@ -696,6 +717,111 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { Ok(ExprSimplifyResult::Original(args)) } + /// Returns a single contiguous preimage for this function and the specified + /// scalar expression, if any. + /// + /// Currently only applies to `=, !=, >, >=, <, <=, is distinct from, is not distinct from` predicates + /// # Return Value + /// + /// Implementations should return a half-open interval: inclusive lower + /// bound and exclusive upper bound. This is slightly different from normal + /// [`Interval`] semantics where the upper bound is closed (inclusive). + /// Typically this means the upper endpoint must be adjusted to the next + /// value not included in the preimage. See the Half-Open Intervals section + /// below for more details. + /// + /// # Background + /// + /// Inspired by the [ClickHouse Paper], a "preimage rewrite" transforms a + /// predicate containing a function call into a predicate containing an + /// equivalent set of input literal (constant) values. The resulting + /// predicate can often be further optimized by other rewrites (see + /// Examples). + /// + /// From the paper: + /// + /// > some functions can compute the preimage of a given function result. + /// > This is used to replace comparisons of constants with function calls + /// > on the key columns by comparing the key column value with the preimage. + /// > For example, `toYear(k) = 2024` can be replaced by + /// > `k >= 2024-01-01 && k < 2025-01-01` + /// + /// For example, given an expression like + /// ```sql + /// date_part('YEAR', k) = 2024 + /// ``` + /// + /// The interval `[2024-01-01, 2025-12-31`]` contains all possible input + /// values (preimage values) for which the function `date_part(YEAR, k)` + /// produces the output value `2024` (image value). 
Returning the interval + /// (note upper bound adjusted up) `[2024-01-01, 2025-01-01]` the expression + /// can be rewritten to + /// + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// + /// which is a simpler and a more canonical form, making it easier for other + /// optimizer passes to recognize and apply further transformations. + /// + /// # Examples + /// + /// Case 1: + /// + /// Original: + /// ```sql + /// date_part('YEAR', k) = 2024 AND k >= '2024-06-01' + /// ``` + /// + /// After preimage rewrite: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' AND k >= '2024-06-01' + /// ``` + /// + /// Since this form is much simpler, the optimizer can combine and simplify + /// sub-expressions further into: + /// ```sql + /// k >= '2024-06-01' AND k < '2025-01-01' + /// ``` + /// + /// Case 2: + /// + /// For min/max pruning, simpler predicates such as: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// are much easier for the pruner to reason about. See [PruningPredicate] + /// for the backgrounds of predicate pruning. + /// + /// The trade-off with the preimage rewrite is that evaluating the rewritten + /// form might be slightly more expensive than evaluating the original + /// expression. In practice, this cost is usually outweighed by the more + /// aggressive optimization opportunities it enables. + /// + /// # Half-Open Intervals + /// + /// The preimage API uses half-open intervals, which makes the rewrite + /// easier to implement by avoiding calculations to adjust the upper bound. + /// For example, if a function returns its input unchanged and the desired + /// output is the single value `5`, a closed interval could be represented + /// as `[5, 5]`, but then the rewrite would require adjusting the upper + /// bound to `6` to create a proper range predicate. With a half-open + /// interval, the same range is represented as `[5, 6)`, which already + /// forms a valid predicate. 
+ /// + /// [PruningPredicate]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html + /// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf + /// [image]: https://en.wikipedia.org/wiki/Image_(mathematics)#Image_of_an_element + /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image + fn preimage( + &self, + _args: &[Expr], + _lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + Ok(PreimageResult::None) + } + /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. /// @@ -846,6 +972,20 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns placement information for this function. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + /// + /// The default implementation returns [`ExpressionPlacement::KeepInPlace`], + /// meaning the expression should be kept where it is in the plan. + /// + /// Override this method to indicate that the function can be pushed down + /// closer to the data source. + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + ExpressionPlacement::KeepInPlace + } } /// ScalarUDF that adds an alias to the underlying function. 
It is better to @@ -926,6 +1066,15 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.simplify(args, info) } + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + fn conditional_arguments<'a>( &self, args: &'a [Expr], @@ -964,6 +1113,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) + } } #[cfg(test)] diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index de4ebf5fa96e9..b19299981cef3 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -312,6 +312,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) @@ -937,6 +938,7 @@ pub fn find_valid_equijoin_key_pair( /// round(Float32) /// ``` #[expect(clippy::needless_pass_by_value)] +#[deprecated(since = "53.0.0", note = "Internal function")] pub fn generate_signature_error_msg( func_name: &str, func_signature: Signature, @@ -958,6 +960,26 @@ pub fn generate_signature_error_msg( ) } +/// Creates a detailed error message for a function with wrong signature. +/// +/// For example, a query like `select round(3.14, 1.1);` would yield: +/// ```text +/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. 
+/// Candidate functions: +/// round(Float64, Int64) +/// round(Float32, Int64) +/// round(Float64) +/// round(Float32) +/// ``` +pub(crate) fn generate_signature_error_message( + func_name: &str, + func_signature: &Signature, + input_expr_types: &[DataType], +) -> String { + #[expect(deprecated)] + generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types) +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. @@ -1734,7 +1756,8 @@ mod tests { .expect("valid parameter names"); // Generate error message with only 1 argument provided - let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]); + let error_msg = + generate_signature_error_message("substr", &sig, &[DataType::Utf8]); assert!( error_msg.contains("str: Utf8, start_pos: Int64"), @@ -1753,7 +1776,8 @@ mod tests { Volatility::Immutable, ); - let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); + let error_msg = + generate_signature_error_message("my_func", &sig, &[DataType::Int32]); assert!( error_msg.contains("Any, Any"), diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index c879b022067c3..94e1d03d0832c 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -367,10 +367,6 @@ pub(crate) mod tests { ) -> Result { unimplemented!() } - - fn statistics(&self) -> Result { - unimplemented!() - } } #[test] diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index bf0cf9b122c1c..2ca9b8f6f495a 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -24,7 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] pub mod arrow_wrappers; pub mod catalog_provider; diff --git 
a/datafusion/functions-aggregate-common/benches/accumulate.rs b/datafusion/functions-aggregate-common/benches/accumulate.rs index f1e4fe23cbb15..aceec57df9666 100644 --- a/datafusion/functions-aggregate-common/benches/accumulate.rs +++ b/datafusion/functions-aggregate-common/benches/accumulate.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int64Array}; diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs index 61b880095047c..574d160d4214a 100644 --- a/datafusion/functions-aggregate-common/src/lib.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -31,8 +31,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod accumulator; pub mod aggregate; diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 225c61b71939e..a7450f0eb52e9 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -49,17 +49,6 @@ macro_rules! cast_scalar_f64 { }; } -// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or -// panic. -macro_rules! cast_scalar_u64 { - ($value:expr ) => { - match &$value { - ScalarValue::UInt64(Some(v)) => *v, - v => panic!("invalid type {}", v), - } - }; -} - /// Centroid implementation to the cluster mentioned in the paper. 
#[derive(Debug, PartialEq, Clone)] pub struct Centroid { @@ -110,7 +99,7 @@ pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, - count: u64, + count: f64, max: f64, min: f64, } @@ -120,8 +109,8 @@ impl TDigest { TDigest { centroids: Vec::new(), max_size, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -133,14 +122,14 @@ impl TDigest { centroids: vec![centroid.clone()], max_size, sum: centroid.mean * centroid.weight, - count: 1, + count: centroid.weight, max: centroid.mean, min: centroid.mean, } } #[inline] - pub fn count(&self) -> u64 { + pub fn count(&self) -> f64 { self.count } @@ -170,8 +159,8 @@ impl Default for TDigest { TDigest { centroids: Vec::new(), max_size: 100, - sum: 0_f64, - count: 0, + sum: 0.0, + count: 0.0, max: f64::NAN, min: f64::NAN, } @@ -216,12 +205,12 @@ impl TDigest { } let mut result = TDigest::new(self.max_size()); - result.count = self.count() + sorted_values.len() as u64; + result.count = self.count() + sorted_values.len() as f64; let maybe_min = *sorted_values.first().unwrap(); let maybe_max = *sorted_values.last().unwrap(); - if self.count() > 0 { + if self.count() > 0.0 { result.min = self.min.min(maybe_min); result.max = self.max.max(maybe_max); } else { @@ -233,7 +222,7 @@ impl TDigest { let mut k_limit: u64 = 1; let mut q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; let mut iter_centroids = self.centroids.iter().peekable(); @@ -281,7 +270,7 @@ impl TDigest { compressed.push(curr.clone()); q_limit_times_count = - Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + Self::k_to_q(k_limit, self.max_size) * result.count(); k_limit += 1; curr = next; } @@ -353,7 +342,7 @@ impl TDigest { let mut centroids: Vec = Vec::with_capacity(n_centroids); let mut starts: Vec = Vec::with_capacity(digests.len()); - let mut count = 0; + let mut count = 0.0; let mut min = 
f64::INFINITY; let mut max = f64::NEG_INFINITY; @@ -362,7 +351,7 @@ impl TDigest { starts.push(start); let curr_count = digest.count(); - if curr_count > 0 { + if curr_count > 0.0 { min = min.min(digest.min); max = max.max(digest.max); count += curr_count; @@ -373,6 +362,11 @@ impl TDigest { } } + // If no centroids were added (all digests had zero count), return default + if centroids.is_empty() { + return TDigest::default(); + } + let mut digests_per_block: usize = 1; while digests_per_block < starts.len() { for i in (0..starts.len()).step_by(digests_per_block * 2) { @@ -397,7 +391,7 @@ impl TDigest { let mut compressed: Vec = Vec::with_capacity(max_size); let mut k_limit = 1; - let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; let mut iter_centroids = centroids.iter_mut(); let mut curr = iter_centroids.next().unwrap(); @@ -416,7 +410,7 @@ impl TDigest { sums_to_merge = 0_f64; weights_to_merge = 0_f64; compressed.push(curr.clone()); - q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + q_limit_times_count = Self::k_to_q(k_limit, max_size) * count; k_limit += 1; curr = centroid; } @@ -440,7 +434,7 @@ impl TDigest { return 0.0; } - let rank = q * self.count as f64; + let rank = q * self.count; let mut pos: usize; let mut t; @@ -450,7 +444,7 @@ impl TDigest { } pos = 0; - t = self.count as f64; + t = self.count; for (k, centroid) in self.centroids.iter().enumerate().rev() { t -= centroid.weight(); @@ -563,7 +557,7 @@ impl TDigest { vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), - ScalarValue::UInt64(Some(self.count)), + ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), ScalarValue::List(arr), @@ -611,7 +605,7 @@ impl TDigest { Self { max_size, sum: cast_scalar_f64!(state[1]), - count: cast_scalar_u64!(&state[2]), + count: 
cast_scalar_f64!(state[2]), max, min, centroids, diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 8f8697fef0a1f..39337e44bb051 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -53,6 +53,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } paste = { workspace = true } [dev-dependencies] diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index d7f687386333f..793c2aac96293 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { let list_item_data_type = values.as_list::().values().data_type().clone(); c.bench_function(name, |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 711bbe5a3c4df..48f71858c1204 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -130,7 +130,7 @@ fn count_benchmark(c: &mut Criterion) { let mut accumulator = prepare_accumulator(); c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( accumulator .update_batch(std::slice::from_ref(&values)) diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 739e333b54617..2205b009ecb27 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ 
b/datafusion/functions-aggregate/src/approx_median.rs @@ -110,7 +110,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), - Field::new(format_state_name(args.name, "count"), UInt64, false), + Field::new(format_state_name(args.name, "count"), Float64, false), Field::new(format_state_name(args.name, "max"), Float64, false), Field::new(format_state_name(args.name, "min"), Float64, false), Field::new_list( diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index b1e649ec029ff..392a044d01394 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -259,7 +259,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new( format_state_name(args.name, "count"), - DataType::UInt64, + DataType::Float64, false, ), Field::new( @@ -436,7 +436,7 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&mut self) -> Result { - if self.digest.count() == 0 { + if self.digest.count() == 0.0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -513,8 +513,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000); + assert_eq!(accumulator.digest.count(), 50_000.0); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000); + assert_eq!(accumulator.digest.count(), 100_000.0); } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ff7762e816ad6..6fd90130e6741 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs 
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -111,20 +111,12 @@ An alternative syntax is also supported: description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, } -impl Debug for ApproxPercentileContWithWeight { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ApproxPercentileContWithWeight") - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxPercentileContWithWeight { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index a107024e2fb4f..77b99cd1ae993 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -114,11 +114,7 @@ pub struct BoolAnd { impl BoolAnd { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } @@ -251,11 +247,7 @@ pub struct BoolOr { impl BoolOr { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 119f861a57608..6c76c6e940099 100644 --- a/datafusion/functions-aggregate/src/correlation.rs 
+++ b/datafusion/functions-aggregate/src/correlation.rs @@ -367,7 +367,7 @@ fn accumulate_correlation_states( /// where: /// n = number of observations /// sum_x = sum of x values -/// sum_y = sum of y values +/// sum_y = sum of y values /// sum_xy = sum of (x * y) /// sum_xx = sum of x^2 values /// sum_yy = sum of y^2 values diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 10cc2ad33f563..376cf39745903 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -147,20 +147,11 @@ pub fn count_all_window() -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Count { signature: Signature, } -impl Debug for Count { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("Count") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Count { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index e86d742db3d45..8252cf1b19c4e 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,19 +17,13 @@ //! [`CovarianceSample`]: covariance sample aggregations. 
-use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, Float64Array, UInt64Array}, - compute::kernels::cast, - datatypes::{DataType, Field}, -}; -use datafusion_common::{ - Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -69,21 +63,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovarianceSample { signature: Signature, aliases: Vec, } -impl Debug for CovarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovarianceSample { fn default() -> Self { Self::new() @@ -94,7 +79,10 @@ impl CovarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("covar")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -112,11 +100,7 @@ impl AggregateUDFImpl for CovarianceSample { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -165,20 +149,11 @@ impl 
AggregateUDFImpl for CovarianceSample { standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovariancePopulation { signature: Signature, } -impl Debug for CovariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovariancePopulation { fn default() -> Self { Self::new() @@ -188,7 +163,10 @@ impl Default for CovariancePopulation { impl CovariancePopulation { pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -206,11 +184,7 @@ impl AggregateUDFImpl for CovariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -304,30 +278,15 @@ impl Accumulator for CovarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = 
match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, }; - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); let new_count = self.count + 1; let delta1 = value1 - self.mean1; let new_mean1 = delta1 / new_count as f64 + self.mean1; @@ -345,29 +304,14 @@ impl Accumulator for CovarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, + }; let new_count = self.count - 1; let delta1 = self.mean1 - value1; @@ -386,10 +330,10 @@ impl Accumulator for CovarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means1 = as_float64_array(&states[1])?; + let means2 = as_float64_array(&states[2])?; + let cs = as_float64_array(&states[3])?; for i in 0..counts.len() { let c = 
counts.value(i); diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 5f3490f535a46..b339479b35e9d 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -90,22 +90,12 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for FirstValue { fn default() -> Self { Self::new() @@ -1040,22 +1030,12 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for LastValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("LastValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for LastValue { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 43218b1147d39..c7af2df4b10fc 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -18,7 +18,6 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt; use arrow::datatypes::Field; use arrow::datatypes::{DataType, FieldRef}; @@ -60,20 +59,11 @@ make_udaf_expr_and_func!( description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Grouping { signature: Signature, } -impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Grouping") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Grouping { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index f364b785ddaed..1b9996220d882 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Aggregate Function packages for [DataFusion]. //! diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f137ae0801f09..db769918d1353 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -85,20 +85,11 @@ make_udaf_expr_and_func!( /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. 
-#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Median { signature: Signature, } -impl Debug for Median { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("Median") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Median { fn default() -> Self { Self::new() @@ -566,10 +557,8 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let mut d = std::mem::take(&mut self.distinct_values.values) - .into_iter() - .map(|v| v.0) - .collect::>(); + let mut d: Vec = + self.distinct_values.values.iter().map(|v| v.0).collect(); let median = calculate_median::(&mut d); ScalarValue::new_primitive::(median, &self.data_type) } diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 37f4ffd9d1707..1aa150b56350b 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -26,11 +26,11 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ array::{Array, ArrayRef, AsArray}, - datatypes::{ - ArrowNativeType, DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type, - }, + datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type}, }; +use num_traits::AsPrimitive; + use arrow::array::ArrowNativeTypeOp; use datafusion_common::internal_err; use datafusion_common::types::{NativeType, logical_float64}; @@ -68,7 +68,10 @@ use crate::utils::validate_percentile_expr; /// The interpolation formula: `lower + (upper - lower) * fraction` /// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION` /// to avoid floating-point operations on integer types while maintaining precision. 
-const INTERPOLATION_PRECISION: usize = 1_000_000; +/// +/// The interpolation arithmetic is performed in f64 and then cast back to the +/// native type to avoid overflowing Float16 intermediates. +const INTERPOLATION_PRECISION: f64 = 1_000_000.0; create_func!(PercentileCont, percentile_cont_udaf); @@ -389,7 +392,12 @@ impl PercentileContAccumulator { } } -impl Accumulator for PercentileContAccumulator { +impl Accumulator for PercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { // Convert `all_values` to `ListArray` and return a single List ScalarValue @@ -493,8 +501,11 @@ impl PercentileContGroupsAccumulator { } } -impl GroupsAccumulator - for PercentileContGroupsAccumulator +impl GroupsAccumulator for PercentileContGroupsAccumulator +where + T: ArrowNumericType + Send, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, { fn update_batch( &mut self, @@ -673,7 +684,12 @@ impl DistinctPercentileContAccumulator { } } -impl Accumulator for DistinctPercentileContAccumulator { +impl Accumulator for DistinctPercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { self.distinct_values.state() } @@ -728,7 +744,11 @@ impl Accumulator for DistinctPercentileContAccumula fn calculate_percentile( values: &mut [T::Native], percentile: f64, -) -> Option { +) -> Option +where + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); @@ -772,22 +792,47 @@ fn calculate_percentile( let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp); let upper_value = *upper_value; - // Linear interpolation using wrapping arithmetic - // We use wrapping operations here (matching the approach in median.rs) because: - // 1. Both values come from the input data, so diff is bounded by the value range - // 2. 
fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough - // to prevent overflow when combined with typical numeric ranges - // 3. The result is guaranteed to be between lower_value and upper_value - // 4. For floating-point types, wrapping ops behave the same as standard ops + // Linear interpolation. + // We compute a quantized interpolation weight using `INTERPOLATION_PRECISION` because: + // 1. Both values come from the input data, so (upper - lower) is bounded by the value range + // 2. fraction is between 0 and 1; quantizing it provides stable, predictable results + // 3. The result is guaranteed to be between lower_value and upper_value (modulo cast rounding) + // 4. Arithmetic is performed in f64 and cast back to avoid overflowing Float16 intermediates let fraction = index - (lower_index as f64); - let diff = upper_value.sub_wrapping(lower_value); - let interpolated = lower_value.add_wrapping( - diff.mul_wrapping(T::Native::usize_as( - (fraction * INTERPOLATION_PRECISION as f64) as usize, - )) - .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)), - ); - Some(interpolated) + let scaled = (fraction * INTERPOLATION_PRECISION) as usize; + let weight = scaled as f64 / INTERPOLATION_PRECISION; + + let lower_f: f64 = lower_value.as_(); + let upper_f: f64 = upper_value.as_(); + let interpolated_f = lower_f + (upper_f - lower_f) * weight; + Some(interpolated_f.as_()) } } } + +#[cfg(test)] +mod tests { + use super::calculate_percentile; + use half::f16; + + #[test] + fn f16_interpolation_does_not_overflow_to_nan() { + // Regression test for https://github.com/apache/datafusion/issues/18945 + // Interpolating between 0 and the max finite f16 value previously overflowed + // intermediate f16 computations and produced NaN. 
+ let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)]; + let result = + calculate_percentile::(&mut values, 0.5) + .expect("non-empty input"); + let result_f = result.to_f32(); + assert!( + !result_f.is_nan(), + "expected non-NaN result, got {result_f}" + ); + // 0.5 percentile should be close to midpoint + assert!( + (result_f - 32752.0).abs() < 1.0, + "unexpected result {result_f}" + ); + } +} diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index bbc5567dab9d6..066fa3c5f32e7 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,20 +17,12 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::Float64Array; use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::{ - HashMap, Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{HashMap, Result, ScalarValue}; use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -58,26 +50,20 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Regr { signature: Signature, regr_type: RegrType, func_name: &'static str, } -impl Debug for Regr { - fn 
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("regr") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Regr { pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), regr_type, func_name, } @@ -468,11 +454,7 @@ impl AggregateUDFImpl for Regr { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { if matches!(self.regr_type, RegrType::Count) { Ok(DataType::UInt64) } else { @@ -606,32 +588,18 @@ impl Accumulator for RegrAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // regr_slope(Y, X) calculates k in y = k*x + b - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None - }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = 
unwrap_or_internal_err!(value_x); - self.count += 1; let delta_x = value_x - self.mean_x; let delta_y = value_y - self.mean_y; @@ -652,32 +620,18 @@ impl Accumulator for RegrAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None - }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - if self.count > 1 { self.count -= 1; let delta_x = value_x - self.mean_x; @@ -703,12 +657,12 @@ impl Accumulator for RegrAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let count_arr = downcast_value!(states[0], UInt64Array); - let mean_x_arr = downcast_value!(states[1], Float64Array); - let mean_y_arr = downcast_value!(states[2], Float64Array); - let m2_x_arr = downcast_value!(states[3], Float64Array); - let m2_y_arr = downcast_value!(states[4], Float64Array); - let algo_const_arr = downcast_value!(states[5], Float64Array); + let count_arr = as_uint64_array(&states[0])?; + let mean_x_arr = as_float64_array(&states[1])?; + let mean_y_arr = as_float64_array(&states[2])?; + let m2_x_arr = 
as_float64_array(&states[3])?; + let m2_y_arr = as_float64_array(&states[4])?; + let algo_const_arr = as_float64_array(&states[5])?; for i in 0..count_arr.len() { let count_b = count_arr.value(i); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 13eb5e1660b52..6f77e7df92547 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -18,7 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -26,8 +26,8 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::ScalarValue; use datafusion_common::{Result, internal_err, not_impl_err}; -use datafusion_common::{ScalarValue, plan_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -62,21 +62,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Stddev { signature: Signature, alias: Vec, } -impl Debug for Stddev { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Stddev") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Stddev { fn default() -> Self { Self::new() @@ -87,7 +78,7 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } @@ -180,20 
+171,11 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV_POP population aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct StddevPop { signature: Signature, } -impl Debug for StddevPop { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StddevPop") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for StddevPop { fn default() -> Self { Self::new() @@ -204,7 +186,7 @@ impl StddevPop { /// Create a new STDDEV_POP aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -249,11 +231,7 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("StddevPop requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -318,13 +296,8 @@ impl Accumulator for StddevAccumulator { fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))), _ => internal_err!("Variance should be f64"), } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 9e35bf0a2bea7..fb089ba4f9cea 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -22,10 +22,10 @@ use arrow::datatypes::{FieldRef, Float64Type}; use arrow::{ array::{Array, 
ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, - compute::kernels::cast, datatypes::{DataType, Field}, }; -use datafusion_common::{Result, ScalarValue, downcast_value, plan_err}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility, @@ -62,21 +62,12 @@ make_udaf_expr_and_func!( syntax_example = "var(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VarianceSample { signature: Signature, aliases: Vec, } -impl Debug for VarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VarianceSample { fn default() -> Self { Self::new() @@ -87,7 +78,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -171,21 +162,12 @@ impl AggregateUDFImpl for VarianceSample { syntax_example = "var_pop(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VariancePopulation { signature: Signature, aliases: Vec, } -impl Debug for VariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VariancePopulation { fn default() -> Self { Self::new() @@ -196,7 +178,7 @@ impl VariancePopulation { pub fn new() -> Self { Self { aliases: 
vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -214,11 +196,7 @@ impl AggregateUDFImpl for VariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Variance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -278,6 +256,7 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -365,10 +344,8 @@ impl Accumulator for VarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { (self.count, self.mean, self.m2) = update(self.count, self.mean, self.m2, value) } @@ -377,10 +354,8 @@ impl Accumulator for VarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { let new_count = self.count - 1; let delta1 = self.mean - value; let new_mean = delta1 / new_count as f64 + self.mean; @@ -396,9 +371,9 @@ impl Accumulator for VarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = downcast_value!(states[2], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means = as_float64_array(&states[1])?; + let m2s = 
as_float64_array(&states[2])?; for i in 0..counts.len() { let c = counts.value(i); @@ -533,8 +508,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &cast(&values[0], &DataType::Float64)?; - let values = downcast_value!(values, Float64Array); + let values = as_float64_array(&values[0])?; self.resize(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { @@ -561,9 +535,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); // first batch is counts, second is partial means, third is partial m2s - let partial_counts = downcast_value!(values[0], UInt64Array); - let partial_means = downcast_value!(values[1], Float64Array); - let partial_m2s = downcast_value!(values[2], Float64Array); + let partial_counts = as_uint64_array(&values[0])?; + let partial_means = as_float64_array(&values[1])?; + let partial_m2s = as_float64_array(&values[2])?; self.resize(total_num_groups); Self::merge( @@ -633,9 +607,7 @@ impl DistinctVarianceAccumulator { impl Accumulator for DistinctVarianceAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let cast_values = cast(&values[0], &DataType::Float64)?; - self.distinct_values - .update_batch(vec![cast_values].as_ref()) + self.distinct_values.update_batch(values) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 6b0241a10a544..e5e601f30ae84 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -84,3 +84,15 @@ name = "array_slice" [[bench]] harness = false name = "map" + +[[bench]] +harness = false +name = "array_remove" + +[[bench]] +harness = false +name = "array_repeat" + +[[bench]] +harness = false +name = "array_set_ops" diff --git 
a/datafusion/functions-nested/benches/array_expression.rs b/datafusion/functions-nested/benches/array_expression.rs index 8d72ffa3c1cd5..ad9f565f4d643 100644 --- a/datafusion/functions-nested/benches/array_expression.rs +++ b/datafusion/functions-nested/benches/array_expression.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; - -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::lit; use datafusion_functions_nested::expr_fn::{array_replace_all, make_array}; use std::hint::black_box; diff --git a/datafusion/functions-nested/benches/array_has.rs b/datafusion/functions-nested/benches/array_has.rs index a44a80c6ae63e..d96f26d410dd0 100644 --- a/datafusion/functions-nested/benches/array_has.rs +++ b/datafusion/functions-nested/benches/array_has.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; - -use criterion::{BenchmarkId, Criterion}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; use datafusion_expr::lit; use datafusion_functions_nested::expr_fn::{ array_has, array_has_all, array_has_any, make_array, diff --git a/datafusion/functions-nested/benches/array_remove.rs b/datafusion/functions-nested/benches/array_remove.rs new file mode 100644 index 0000000000000..a494d322392a8 --- /dev/null +++ b/datafusion/functions-nested/benches/array_remove.rs @@ -0,0 +1,572 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Decimal128Array, FixedSizeBinaryArray, + Float64Array, Int64Array, ListArray, StringArray, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::remove::ArrayRemove; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const ARRAY_SIZES: &[usize] = &[10, 100, 500]; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + // Test array_remove with different data types and array sizes + // TODO: Add performance tests for nested datatypes + bench_array_remove_int64(c); + bench_array_remove_f64(c); + bench_array_remove_strings(c); + bench_array_remove_binary(c); + bench_array_remove_boolean(c); + bench_array_remove_decimal64(c); + bench_array_remove_fixed_size_binary(c); +} + +fn bench_array_remove_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_int64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Int64(Some(1)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( 
+ BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Int64, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_f64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_f64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_f64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Float64(Some(1.0)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Float64, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_strings"); + + for &array_size in ARRAY_SIZES { + let list_array = create_string_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Utf8(Some("value_1".to_string())); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + 
&array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Utf8, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_binary"); + + for &array_size in ARRAY_SIZES { + let list_array = create_binary_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Binary(Some(b"value_1".to_vec())); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Binary, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_boolean"); + + for &array_size in ARRAY_SIZES { + let list_array = create_boolean_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Boolean(Some(true)); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = 
ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Boolean, false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_decimal64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_decimal64"); + + for &array_size in ARRAY_SIZES { + let list_array = create_decimal64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::Decimal128(Some(100_i128), 10, 2); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::Decimal128(10, 2), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_fixed_size_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_fixed_size_binary"); + + for &array_size in ARRAY_SIZES { + let list_array = + create_fixed_size_binary_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let element_to_remove = ScalarValue::FixedSizeBinary(16, Some(vec![1u8; 16])); + let args = create_args(list_array.clone(), element_to_remove.clone()); + + group.bench_with_input( + BenchmarkId::new("remove", 
array_size), + &array_size, + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("arr", list_array.data_type().clone(), false) + .into(), + Field::new("el", DataType::FixedSizeBinary(16), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + list_array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn create_args(list_array: ArrayRef, element: ScalarValue) -> Vec { + vec![ + ColumnarValue::Array(list_array), + ColumnarValue::Scalar(element), + ] +} + +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_f64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0.0..array_size as f64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Float64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_string_list_array( + num_rows: usize, + array_size: usize, + 
null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}")) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_binary_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}").into_bytes()) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Binary, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_boolean_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random::()) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Boolean, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_decimal64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if 
rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size) as i128 * 100) + } + }) + .collect::() + .with_precision_and_scale(10, 2) + .unwrap(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Decimal128(10, 2), true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_fixed_size_binary_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buffer = Vec::with_capacity(num_rows * array_size * 16); + let mut null_buffer = Vec::with_capacity(num_rows * array_size); + for _ in 0..num_rows * array_size { + if rng.random::() < null_density { + null_buffer.push(false); + buffer.extend_from_slice(&[0u8; 16]); + } else { + null_buffer.push(true); + let mut bytes = [0u8; 16]; + rng.fill(&mut bytes); + buffer.extend_from_slice(&bytes); + } + } + let nulls = arrow::buffer::NullBuffer::from_iter(null_buffer.iter().copied()); + let values = FixedSizeBinaryArray::new(16, buffer.into(), Some(nulls)); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::FixedSizeBinary(16), true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_repeat.rs b/datafusion/functions-nested/benches/array_repeat.rs new file mode 100644 index 0000000000000..0ce8db00ceb8f --- /dev/null +++ b/datafusion/functions-nested/benches/array_repeat.rs @@ -0,0 +1,476 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: &[usize] = &[100, 1000, 10000]; +const REPEAT_COUNTS: &[u64] = &[5, 50]; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + // Test array_repeat with different element types + bench_array_repeat_int64(c); + bench_array_repeat_string(c); + bench_array_repeat_float64(c); + bench_array_repeat_boolean(c); + + // Test array_repeat with list element (nested arrays) + bench_array_repeat_nested_int64_list(c); + bench_array_repeat_nested_string_list(c); +} + +fn bench_array_repeat_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_int64"); + + for &num_rows in NUM_ROWS { + let element_array = create_int64_array(num_rows, 
NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Int64, false).into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_string(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_string"); + + for &num_rows in NUM_ROWS { + let element_array = create_string_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Utf8, false).into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Utf8, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_int64_list(c: &mut Criterion) { + let mut group 
= c.benchmark_group("array_repeat_nested_int64"); + + for &num_rows in NUM_ROWS { + let list_array = create_int64_list_array(num_rows, 5, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_float64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_float64"); + + for &num_rows in NUM_ROWS { + let element_array = create_float64_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Float64, false) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Float64, + true, + ))), + false, + ) + .into(), + 
config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_boolean"); + + for &num_rows in NUM_ROWS { + let element_array = create_boolean_array(num_rows, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Boolean, false) + .into(), + Field::new("count", DataType::UInt64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Boolean, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_string_list(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_nested_string"); + + for &num_rows in NUM_ROWS { + let list_array = create_string_list_array(num_rows, 5, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::UInt64, false).into(), 
+ ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_int64_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..1000)) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_string_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + use arrow::array::StringArray; + + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..1000)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn create_float64_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0.0..1000.0)) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_boolean_array(num_rows: usize, null_density: f64) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = 
(0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random()) + } + }) + .collect::(); + + Arc::new(values) +} + +fn create_string_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + use arrow::array::StringArray; + + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_reverse.rs b/datafusion/functions-nested/benches/array_reverse.rs index 92a65128fe6ba..0c37296188315 100644 --- a/datafusion/functions-nested/benches/array_reverse.rs +++ b/datafusion/functions-nested/benches/array_reverse.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; - use std::{hint::black_box, sync::Arc}; -use crate::criterion::Criterion; use arrow::{ array::{ArrayRef, FixedSizeListArray, Int32Array, ListArray, ListViewArray}, buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}, datatypes::{DataType, Field}, }; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_functions_nested::reverse::array_reverse_inner; fn array_reverse(array: &ArrayRef) -> ArrayRef { diff --git a/datafusion/functions-nested/benches/array_set_ops.rs b/datafusion/functions-nested/benches/array_set_ops.rs new file mode 100644 index 0000000000000..e3146921d7fe1 --- /dev/null +++ b/datafusion/functions-nested/benches/array_set_ops.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::set_ops::{ArrayDistinct, ArrayIntersect, ArrayUnion}; +use rand::SeedableRng; +use rand::prelude::SliceRandom; +use rand::rngs::StdRng; +use std::collections::HashSet; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const ARRAY_SIZES: &[usize] = &[10, 50, 100]; +const SEED: u64 = 42; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_union(c); + bench_array_intersect(c); + bench_array_distinct(c); +} + +fn invoke_udf(udf: &impl ScalarUDFImpl, array1: &ArrayRef, array2: &ArrayRef) { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(array1.clone()), + ColumnarValue::Array(array2.clone()), + ], + arg_fields: vec![ + Field::new("arr1", array1.data_type().clone(), false).into(), + Field::new("arr2", array2.data_type().clone(), false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new("result", array1.data_type().clone(), false).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ); +} + +fn bench_array_union(c: &mut Criterion) { + let mut group = c.benchmark_group("array_union"); + let udf = ArrayUnion::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_intersect(c: &mut Criterion) { + let mut group = 
c.benchmark_group("array_intersect"); + let udf = ArrayIntersect::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_distinct(c: &mut Criterion) { + let mut group = c.benchmark_group("array_distinct"); + let udf = ArrayDistinct::new(); + + for (duplicate_label, duplicate_ratio) in + &[("high_duplicate", 0.8), ("low_duplicate", 0.2)] + { + for &array_size in ARRAY_SIZES { + let array = + create_array_with_duplicates(NUM_ROWS, array_size, *duplicate_ratio); + group.bench_with_input( + BenchmarkId::new(*duplicate_label, array_size), + &array_size, + |b, _| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + arg_fields: vec![ + Field::new("arr", array.data_type().clone(), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_arrays_with_overlap( + num_rows: usize, + array_size: usize, + overlap_ratio: f64, +) -> (ArrayRef, ArrayRef) { + assert!((0.0..=1.0).contains(&overlap_ratio)); + let overlap_count = ((array_size as f64) * overlap_ratio).round() as usize; + + let mut rng = StdRng::seed_from_u64(SEED); + + let mut values1 = Vec::with_capacity(num_rows * array_size); + let mut values2 = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + for i in 0..array_size { + values1.push(base + i as i64); + } + + let mut positions: Vec = 
(0..array_size).collect(); + positions.shuffle(&mut rng); + + let overlap_positions: HashSet<_> = + positions[..overlap_count].iter().copied().collect(); + + for i in 0..array_size { + if overlap_positions.contains(&i) { + values2.push(base + i as i64); + } else { + values2.push(base + array_size as i64 + i as i64); + } + } + } + + let values1 = Int64Array::from(values1); + let values2 = Int64Array::from(values2); + + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + let array1 = Arc::new( + ListArray::try_new( + field.clone(), + OffsetBuffer::new(offsets.clone().into()), + Arc::new(values1), + None, + ) + .unwrap(), + ); + + let array2 = Arc::new( + ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values2), + None, + ) + .unwrap(), + ); + + (array1, array2) +} + +fn create_array_with_duplicates( + num_rows: usize, + array_size: usize, + duplicate_ratio: f64, +) -> ArrayRef { + assert!((0.0..=1.0).contains(&duplicate_ratio)); + let unique_count = ((array_size as f64) * (1.0 - duplicate_ratio)).round() as usize; + let duplicate_count = array_size - unique_count; + + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + // Add unique values first + for i in 0..unique_count { + values.push(base + i as i64); + } + + // Fill the rest with duplicates randomly picked from the unique values + let mut unique_indices: Vec = + (0..unique_count).map(|i| base + i as i64).collect(); + unique_indices.shuffle(&mut rng); + + for i in 0..duplicate_count { + values.push(unique_indices[i % unique_count]); + } + } + + let values = Int64Array::from(values); + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( 
+ ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_slice.rs b/datafusion/functions-nested/benches/array_slice.rs index 858e438996190..b95fe47575e53 100644 --- a/datafusion/functions-nested/benches/array_slice.rs +++ b/datafusion/functions-nested/benches/array_slice.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, }; diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 75b4045a193d5..e50c4659b17cd 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::{Int32Array, ListArray, StringArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 97671d4a95f23..abc0e7406b2c9 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -262,7 +262,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayWrapper<'a> { DataType::FixedSizeList(_, _) => Ok(ArrayWrapper::FixedSizeList( as_fixed_size_list_array(value)?, )), - _ => exec_err!("array_has does not support type '{:?}'.", value.data_type()), + _ => exec_err!("array_has does not support type '{}'.", value.data_type()), } } } diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index a8ac997ce33ec..19a4e9573e35b 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_except function. +//! [`ScalarUDFImpl`] definition for array_except function. 
use crate::utils::{check_datatypes, make_scalar_function}; +use arrow::array::new_null_array; use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, cast::AsArray}; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::utils::{ListCoercion, take_function_args}; @@ -28,6 +29,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -104,8 +106,11 @@ impl ScalarUDFImpl for ArrayExcept { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0].clone(), &arg_types[1].clone()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), + match (&arg_types[0], &arg_types[1]) { + (DataType::Null, DataType::Null) => { + Ok(DataType::new_list(DataType::Null, true)) + } + (DataType::Null, dt) | (dt, DataType::Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -129,8 +134,16 @@ impl ScalarUDFImpl for ArrayExcept { fn array_except_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_except", args)?; + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::Null, DataType::Null) => Ok(new_null_array( + &DataType::new_list(DataType::Null, true), + len, + )), + (DataType::Null, dt @ DataType::List(_)) + | (DataType::Null, dt @ DataType::LargeList(_)) + | (dt @ DataType::List(_), DataType::Null) + | (dt @ DataType::LargeList(_), DataType::Null) => Ok(new_null_array(dt, len)), (DataType::List(field), DataType::List(_)) => { check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); @@ -169,15 +182,27 @@ fn general_except( let mut rows = Vec::with_capacity(l_values.num_rows()); let mut 
dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in r_slice { - let right_row = r_values.row(i); + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + + let l_offsets_iter = l.offsets().iter().tuple_windows(); + let r_offsets_iter = r.offsets().iter().tuple_windows(); + for (list_index, ((l_start, l_end), (r_start, r_end))) in + l_offsets_iter.zip(r_offsets_iter).enumerate() + { + if nulls + .as_ref() + .is_some_and(|nulls| nulls.is_null(list_index)) + { + offsets.push(OffsetSize::usize_as(rows.len())); + continue; + } + + for element_index in r_start.as_usize()..r_end.as_usize() { + let right_row = r_values.row(element_index); dedup.insert(right_row); } - for i in l_slice { - let left_row = l_values.row(i); + for element_index in l_start.as_usize()..l_end.as_usize() { + let left_row = l_values.row(element_index); if dedup.insert(left_row) { rows.push(left_row); } @@ -192,7 +217,7 @@ fn general_except( field.to_owned(), OffsetBuffer::new(offsets.into()), values.to_owned(), - l.nulls().cloned(), + nulls, )) } else { internal_err!("array_except failed to convert rows") diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 33b3e102ae0bc..8c21348507d26 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -208,7 +208,7 @@ fn flatten_inner(args: &[ArrayRef]) -> Result { } Null => Ok(Arc::clone(array)), _ => { - exec_err!("flatten does not support type '{:?}'", array.data_type()) + exec_err!("flatten does not support type '{}'", array.data_type()) } } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index ed9e1af4eaa8f..9ac6911236e40 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -24,8 +24,6 @@ // 
https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Nested type Functions for [DataFusion]. //! diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 410a545853acf..bc899126fb643 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -31,7 +31,6 @@ use arrow::datatypes::DataType; use arrow::datatypes::{DataType::Null, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{Result, plan_err}; -use datafusion_expr::TypeSignature; use datafusion_expr::binary::{ try_type_union_resolution_with_struct, type_union_resolution, }; @@ -80,10 +79,7 @@ impl Default for MakeArray { impl MakeArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::Nullary, TypeSignature::UserDefined], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![String::from("make_list")], } } @@ -125,7 +121,11 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - coerce_types_inner(arg_types, self.name()) + if arg_types.is_empty() { + Ok(vec![]) + } else { + coerce_types_inner(arg_types, self.name()) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index a96bbc0589e3c..7df131cf5e27e 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -119,7 +119,7 @@ fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { ScalarValue::List(array) => Ok(array.value(0)), ScalarValue::LargeList(array) => Ok(array.value(0)), ScalarValue::FixedSizeList(array) => Ok(array.value(0)), - _ => 
exec_err!("Expected array, got {:?}", value), + _ => exec_err!("Expected array, got {}", value), }, ColumnarValue::Array(array) => Ok(array.to_owned()), } diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index afb18a44f48ab..e96fdb7d4baca 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -37,7 +37,7 @@ use std::sync::Arc; use crate::map::map_udf; use crate::{ - array_has::{array_has_all, array_has_udf}, + array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, make_array::make_array, @@ -120,20 +120,6 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } - - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - array_has_udf(), - // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` - vec![expr.right, expr.left], - ), - ))) - } else { - plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) - } - } } #[derive(Debug)] diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index d085fa29cc7e1..fc3a295963ce2 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -164,7 +164,6 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result>() diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 41c06cb9c4cbf..9e957c93e1c66 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -20,8 +20,8 @@ use crate::utils; use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, cast::AsArray, - new_empty_array, + Array, 
ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, + OffsetSizeTrait, cast::AsArray, make_array, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, FieldRef}; @@ -377,73 +377,84 @@ fn general_remove( ); } }; - let data_type = list_field.data_type(); - let mut new_values = vec![]; + let original_data = list_array.values().to_data(); // Build up the offsets for the final output array let mut offsets = Vec::::with_capacity(arr_n.len() + 1); offsets.push(OffsetSize::zero()); - // n is the number of elements to remove in this row - for (row_index, (list_array_row, n)) in - list_array.iter().zip(arr_n.iter()).enumerate() - { - match list_array_row { - Some(list_array_row) => { - let eq_array = utils::compare_element_to_list( - &list_array_row, - element_array, - row_index, - false, - )?; - - // We need to keep at most first n elements as `false`, which represent the elements to remove. - let eq_array = if eq_array.false_count() < *n as usize { - eq_array - } else { - let mut count = 0; - eq_array - .iter() - .map(|e| { - // Keep first n `false` elements, and reverse other elements to `true`. 
- if let Some(false) = e { - if count < *n { - count += 1; - e - } else { - Some(true) - } - } else { - e - } - }) - .collect::() - }; - - let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; - offsets.push( - offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), - ); - new_values.push(filtered_array); - } - None => { - // Null element results in a null row (no new offsets) - offsets.push(offsets[row_index]); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let mut valid = NullBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; + } + + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + // n is the number of elements to remove in this row + let n = arr_n[row_index]; + + // compare each element in the list, `false` means the element matches and should be removed + let eq_array = utils::compare_element_to_list( + &list_array.value(row_index), + element_array, + row_index, + false, + )?; + + let num_to_remove = eq_array.false_count(); + + // Fast path: no elements to remove, copy entire row + if num_to_remove == 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + valid.append_non_null(); + continue; + } + + // Remove at most `n` matching elements + let max_removals = n.min(num_to_remove as i64); + let mut removed = 0i64; + let mut copied = 0usize; + // marks the beginning of a range of elements pending to be copied. 
+ let mut pending_batch_to_retain: Option = None; + for (i, keep) in eq_array.iter().enumerate() { + if keep == Some(false) && removed < max_removals { + // Flush pending batch before skipping this element + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + i); + copied += i - bs; + pending_batch_to_retain = None; + } + removed += 1; + } else if pending_batch_to_retain.is_none() { + pending_batch_to_retain = Some(i); } } - } - let values = if new_values.is_empty() { - new_empty_array(data_type) - } else { - let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); - arrow::compute::concat(&new_values)? - }; + // Flush remaining batch + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + eq_array.len()); + copied += eq_array.len() - bs; + } + + offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); + valid.append_non_null(); + } + let new_values = make_array(mutable.freeze()); Ok(Arc::new(GenericListArray::::try_new( Arc::clone(list_field), OffsetBuffer::new(offsets.into()), - values, - list_array.nulls().cloned(), + new_values, + valid.finish(), )?)) } diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index a121b5f03162e..5e78a4d0f601c 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -19,22 +19,23 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData, - OffsetSizeTrait, UInt64Array, new_null_array, + Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait, + UInt64Array, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute; -use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List}, Field, }; -use datafusion_common::cast::{as_large_list_array, as_list_array, 
as_uint64_array}; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -89,7 +90,17 @@ impl Default for ArrayRepeat { impl ArrayRepeat { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_repeat")], } } @@ -109,10 +120,17 @@ impl ScalarUDFImpl for ArrayRepeat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].clone(), - true, - )))) + let element_type = &arg_types[0]; + match element_type { + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + _ => Ok(List(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + } } fn invoke_with_args( @@ -126,23 +144,6 @@ impl ScalarUDFImpl for ArrayRepeat { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [first_type, second_type] = take_function_args(self.name(), arg_types)?; - - // Coerce the second argument to Int64/UInt64 if it's a numeric type - let second = match second_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - DataType::UInt64 - } - _ => 
return exec_err!("count must be an integer type"), - }; - - Ok(vec![first_type.clone(), second]) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -150,15 +151,7 @@ impl ScalarUDFImpl for ArrayRepeat { fn array_repeat_inner(args: &[ArrayRef]) -> Result { let element = &args[0]; - let count_array = &args[1]; - - let count_array = match count_array.data_type() { - DataType::Int64 => &cast(count_array, &DataType::UInt64)?, - DataType::UInt64 => count_array, - _ => return exec_err!("count must be an integer type"), - }; - - let count_array = as_uint64_array(count_array)?; + let count_array = as_int64_array(&args[1])?; match element.data_type() { List(_) => { @@ -187,45 +180,46 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result { /// ``` fn general_repeat( array: &ArrayRef, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = array.data_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (row_index, &count) in count_vec.iter().enumerate() { - let repeated_array = if array.is_null(row_index) { - new_null_array(data_type, count) - } else { - let original_data = array.to_data(); - let capacity = Capacities::Array(count); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for _ in 0..count { - mutable.extend(0, row_index, row_index + 1); - } - - let data = mutable.freeze(); - arrow::array::make_array(data) - }; - new_values.push(repeated_array); + let total_repeated_values: usize = (0..count_array.len()) + .map(|i| get_count_with_validity(count_array, i)) + .sum(); + + let mut take_indices = Vec::with_capacity(total_repeated_values); + let mut offsets = Vec::with_capacity(count_array.len() + 1); + offsets.push(O::zero()); + let mut running_offset = 0usize; + + for idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, idx); + 
running_offset = running_offset.checked_add(count).ok_or_else(|| { + DataFusionError::Execution( + "array_repeat: running_offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(running_offset).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {running_offset} exceeds the maximum value for offset type" + )) + })?; + offsets.push(offset); + take_indices.extend(std::iter::repeat_n(idx as u64, count)); } - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build the flattened values + let repeated_values = compute::take( + array.as_ref(), + &UInt64Array::from_iter_values(take_indices), + None, + )?; + // Construct final ListArray Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::from_lengths(count_vec), - values, - None, + Arc::new(Field::new_list_field(array.data_type().to_owned(), true)), + OffsetBuffer::new(offsets.into()), + repeated_values, + count_array.nulls().cloned(), )?)) } @@ -241,58 +235,95 @@ fn general_repeat( /// ``` fn general_list_repeat( list_array: &GenericListArray, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = list_array.data_type(); - let value_type = list_array.value_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { - let list_arr = match list_array_row { - Some(list_array_row) => { - let original_data = list_array_row.to_data(); - let capacity = Capacities::Array(original_data.len() * count); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data], - false, - capacity, - ); - - for _ in 0..count { - mutable.extend(0, 0, original_data.len()); - } - - let data = mutable.freeze(); - let repeated_array = 
arrow::array::make_array(data); - - let list_arr = GenericListArray::::try_new( - Arc::new(Field::new_list_field(value_type.clone(), true)), - OffsetBuffer::::from_lengths(vec![original_data.len(); count]), - repeated_array, - None, - )?; - Arc::new(list_arr) as ArrayRef + let list_offsets = list_array.value_offsets(); + + // calculate capacities for pre-allocation + let mut outer_total = 0usize; + let mut inner_total = 0usize; + for i in 0..count_array.len() { + let count = get_count_with_validity(count_array, i); + if count > 0 { + outer_total += count; + if list_array.is_valid(i) { + let len = list_offsets[i + 1].to_usize().unwrap() + - list_offsets[i].to_usize().unwrap(); + inner_total += len * count; } - None => new_null_array(data_type, count), - }; - new_values.push(list_arr); + } } - let lengths = new_values.iter().map(|a| a.len()).collect::>(); - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build inner structures + let mut inner_offsets = Vec::with_capacity(outer_total + 1); + let mut take_indices = Vec::with_capacity(inner_total); + let mut inner_nulls = BooleanBufferBuilder::new(outer_total); + let mut inner_running = 0usize; + inner_offsets.push(O::zero()); + + for row_idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, row_idx); + let list_is_valid = list_array.is_valid(row_idx); + let start = list_offsets[row_idx].to_usize().unwrap(); + let end = list_offsets[row_idx + 1].to_usize().unwrap(); + let row_len = end - start; + + for _ in 0..count { + inner_running = inner_running.checked_add(row_len).ok_or_else(|| { + DataFusionError::Execution( + "array_repeat: inner offset overflowed usize".to_string(), + ) + })?; + let offset = O::from_usize(inner_running).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {inner_running} exceeds the maximum value for offset type" + )) + })?; + inner_offsets.push(offset); + 
inner_nulls.append(list_is_valid); + if list_is_valid { + take_indices.extend(start as u64..end as u64); + } + } + } - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::::from_lengths(lengths), - values, + // Build inner ListArray + let inner_values = compute::take( + list_array.values().as_ref(), + &UInt64Array::from_iter_values(take_indices), None, + )?; + let inner_list = GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type().clone(), true)), + OffsetBuffer::new(inner_offsets.into()), + inner_values, + Some(NullBuffer::new(inner_nulls.finish())), + )?; + + // Build outer ListArray + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field( + list_array.data_type().to_owned(), + true, + )), + OffsetBuffer::::from_lengths( + count_array + .iter() + .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)), + ), + Arc::new(inner_list), + count_array.nulls().cloned(), )?)) } + +/// Helper function to get count from count_array at given index +/// Return 0 for null values or non-positive count. 
+#[inline] +fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize { + if count_array.is_null(idx) { + 0 + } else { + let c = count_array.value(idx); + if c > 0 { c as usize } else { 0 } + } +} diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 69a220e125c04..2348b3c530c53 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -19,11 +19,9 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait, - new_null_array, + Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_empty_array, new_null_array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; -use arrow::compute; use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; @@ -36,7 +34,6 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use itertools::Itertools; use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; @@ -69,7 +66,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.", syntax_example = "array_union(array1, array2)", sql_example = r#"```sql > select array_union([1, 2, 3, 4], [5, 6, 3, 4]); @@ -136,8 +133,7 @@ impl ScalarUDFImpl for ArrayUnion { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) 
=> Ok(dt.clone()), } } @@ -186,11 +182,17 @@ impl ScalarUDFImpl for ArrayUnion { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayIntersect { +pub struct ArrayIntersect { signature: Signature, aliases: Vec, } +impl Default for ArrayIntersect { + fn default() -> Self { + Self::new() + } +} + impl ArrayIntersect { pub fn new() -> Self { Self { @@ -221,8 +223,7 @@ impl ScalarUDFImpl for ArrayIntersect { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -261,7 +262,7 @@ impl ScalarUDFImpl for ArrayIntersect { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayDistinct { +pub struct ArrayDistinct { signature: Signature, aliases: Vec, } @@ -275,6 +276,12 @@ impl ArrayDistinct { } } +impl Default for ArrayDistinct { + fn default() -> Self { + Self::new() + } +} + impl ScalarUDFImpl for ArrayDistinct { fn as_any(&self) -> &dyn Any { self @@ -361,76 +368,118 @@ fn generic_set_lists( "{set_op:?} is not implemented for '{l:?}' and '{r:?}'" ); - let mut offsets = vec![OffsetSize::usize_as(0)]; - let mut new_arrays = vec![]; - let mut new_null_buf = vec![]; + // Convert all values to rows in batch for performance. let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; - for (first_arr, second_arr) in l.iter().zip(r.iter()) { - let mut ele_should_be_null = false; + let rows_l = converter.convert_columns(&[Arc::clone(l.values())])?; + let rows_r = converter.convert_columns(&[Arc::clone(r.values())])?; - let l_values = if let Some(first_arr) = first_arr { - converter.convert_columns(&[first_arr])? 
- } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) - }; + match set_op { + SetOp::Union => generic_set_loop::( + l, r, &rows_l, &rows_r, field, &converter, + ), + SetOp::Intersect => generic_set_loop::( + l, r, &rows_l, &rows_r, field, &converter, + ), + } +} - let r_values = if let Some(second_arr) = second_arr { - converter.convert_columns(&[second_arr])? - } else { - ele_should_be_null = true; - converter.empty_rows(0, 0) - }; - - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect() - } else { - vec![] - }; - - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } +/// Inner loop for set operations, parameterized by const generic to +/// avoid branching inside the hot loop. +fn generic_set_loop( + l: &GenericListArray, + r: &GenericListArray, + rows_l: &arrow::row::Rows, + rows_r: &arrow::row::Rows, + field: Arc, + converter: &RowConverter, +) -> Result { + let l_offsets = l.value_offsets(); + let r_offsets = r.value_offsets(); + + let mut result_offsets = Vec::with_capacity(l.len() + 1); + result_offsets.push(OffsetSize::usize_as(0)); + let initial_capacity = if IS_UNION { + // Union can include all elements from both sides + rows_l.num_rows() + } else { + // Intersect result is bounded by the smaller side + rows_l.num_rows().min(rows_r.num_rows()) + }; + + let mut final_rows = Vec::with_capacity(initial_capacity); + + // Reuse hash sets across iterations + let mut seen = HashSet::new(); + let mut lookup_set = HashSet::new(); + for i in 0..l.len() { + let last_offset = *result_offsets.last().unwrap(); + + if l.is_null(i) || r.is_null(i) { + result_offsets.push(last_offset); + continue; + } + + let l_start = l_offsets[i].as_usize(); + let l_end = l_offsets[i + 1].as_usize(); + let r_start = r_offsets[i].as_usize(); + let r_end = r_offsets[i + 
1].as_usize(); + + seen.clear(); + + if IS_UNION { + for idx in l_start..l_end { + let row = rows_l.row(idx); + if seen.insert(row) { + final_rows.push(row); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + for idx in r_start..r_end { + let row = rows_r.row(idx); + if seen.insert(row) { + final_rows.push(row); } } - } - - let last_offset = match offsets.last() { - Some(offset) => *offset, - None => return internal_err!("offsets should not be empty"), - }; - - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); + } else { + let l_len = l_end - l_start; + let r_len = r_end - r_start; + + // Select shorter side for lookup, longer side for probing + let (lookup_rows, lookup_range, probe_rows, probe_range) = if l_len < r_len { + (rows_l, l_start..l_end, rows_r, r_start..r_end) + } else { + (rows_r, r_start..r_end, rows_l, l_start..l_end) + }; + lookup_set.clear(); + lookup_set.reserve(lookup_range.len()); + + // Build lookup table + for idx in lookup_range { + lookup_set.insert(lookup_rows.row(idx)); } - }; - new_null_buf.push(!ele_should_be_null); - new_arrays.push(array); + // Probe and emit distinct intersected rows + for idx in probe_range { + let row = probe_rows.row(idx); + if lookup_set.contains(&row) && seen.insert(row) { + final_rows.push(row); + } + } + } + result_offsets.push(last_offset + OffsetSize::usize_as(seen.len())); } - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); - let values = compute::concat(&new_arrays_ref)?; + let final_values = if final_rows.is_empty() { + new_empty_array(&l.value_type()) + } else { + let arrays = converter.convert_rows(final_rows)?; + Arc::clone(&arrays[0]) + }; + let arr = 
GenericListArray::::try_new( field, - offsets, - values, - Some(NullBuffer::new(new_null_buf.into())), + OffsetBuffer::new(result_offsets.into()), + final_values, + NullBuffer::union(l.nulls(), r.nulls()), )?; Ok(Arc::new(arr)) } @@ -440,59 +489,13 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { - fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { - let field = Arc::new(Field::new_list_field(data_type.clone(), true)); - let values = new_null_array(data_type, len); - if large { - Ok(Arc::new(LargeListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } else { - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } - } - + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (Null, Null) => Ok(Arc::new(ListArray::new_null( - Arc::new(Field::new_list_field(Null, true)), - array1.len(), - ))), - (Null, List(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array2)?; - general_array_distinct::(array, field) - } - (List(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array1)?; - general_array_distinct::(array, field) - } - (Null, LargeList(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array2)?; - general_array_distinct::(array, field) - } - (LargeList(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array1)?; - general_array_distinct::(array, field) - } + (Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)), + (Null, dt @ List(_)) + | (Null, dt @ LargeList(_)) + | (dt @ List(_), Null) + | (dt @ LargeList(_), Null) 
=> Ok(new_null_array(dt, len)), (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; @@ -528,42 +531,52 @@ fn general_array_distinct( if array.is_empty() { return Ok(Arc::new(array.clone()) as ArrayRef); } + let value_offsets = array.value_offsets(); let dt = array.value_type(); - let mut offsets = Vec::with_capacity(array.len()); + let mut offsets = Vec::with_capacity(array.len() + 1); offsets.push(OffsetSize::usize_as(0)); - let mut new_arrays = Vec::with_capacity(array.len()); - let converter = RowConverter::new(vec![SortField::new(dt)])?; - // distinct for each list in ListArray - for arr in array.iter() { - let last_offset: OffsetSize = offsets.last().copied().unwrap(); - let Some(arr) = arr else { - // Add same offset for null + + // Convert all values to row format in a single batch for performance + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + let rows = converter.convert_columns(&[Arc::clone(array.values())])?; + let mut final_rows = Vec::with_capacity(rows.num_rows()); + let mut seen = HashSet::new(); + for i in 0..array.len() { + let last_offset = *offsets.last().unwrap(); + + // Null list entries produce no output; just carry forward the offset. + if array.is_null(i) { offsets.push(last_offset); continue; - }; - let values = converter.convert_columns(&[arr])?; - // sort elements in list and remove duplicates - let rows = values.iter().sorted().dedup().collect::>(); - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("array_distinct: failed to get array from rows"); + } + + let start = value_offsets[i].as_usize(); + let end = value_offsets[i + 1].as_usize(); + seen.clear(); + seen.reserve(end - start); + + // Walk the sub-array and keep only the first occurrence of each value. 
+ for idx in start..end { + let row = rows.row(idx); + if seen.insert(row) { + final_rows.push(row); } - }; - new_arrays.push(array); - } - if new_arrays.is_empty() { - return Ok(Arc::new(array.clone()) as ArrayRef); + } + offsets.push(last_offset + OffsetSize::usize_as(seen.len())); } - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; + + // Convert all collected distinct rows back + let final_values = if final_rows.is_empty() { + new_empty_array(&dt) + } else { + let arrays = converter.convert_rows(final_rows)?; + Arc::clone(&arrays[0]) + }; + Ok(Arc::new(GenericListArray::::try_new( Arc::clone(field), - offsets, - values, + OffsetBuffer::new(offsets.into()), + final_values, // Keep the list nulls array.nulls().cloned(), )?)) diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index ba2da0f760eee..cbe101f111b26 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -18,16 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array_sort function. 
use crate::utils::make_scalar_function; -use arrow::array::{ - Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, new_null_array, -}; +use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_null_array}; use arrow::buffer::OffsetBuffer; use arrow::compute::SortColumn; use arrow::datatypes::{DataType, FieldRef}; use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, plan_err}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -134,18 +132,7 @@ impl ScalarUDFImpl for ArraySort { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { - DataType::Null => Ok(DataType::Null), - DataType::List(field) => { - Ok(DataType::new_list(field.data_type().clone(), true)) - } - DataType::LargeList(field) => { - Ok(DataType::new_large_list(field.data_type().clone(), true)) - } - arg_type => { - plan_err!("{} does not support type {arg_type}", self.name()) - } - } + Ok(arg_types[0].clone()) } fn invoke_with_args( @@ -206,11 +193,11 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { } DataType::List(field) => { let array = as_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } DataType::LargeList(field) => { let array = as_large_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } // Signature should prevent this arm ever occurring _ => exec_err!("array_sort expects list for first argument"), @@ -219,18 +206,16 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { fn array_sort_generic( list_array: &GenericListArray, - field: &FieldRef, + field: 
FieldRef, sort_options: Option, ) -> Result { let row_count = list_array.len(); let mut array_lengths = vec![]; let mut arrays = vec![]; - let mut valid = NullBufferBuilder::new(row_count); for i in 0..row_count { if list_array.is_null(i) { array_lengths.push(0); - valid.append_null(); } else { let arr_ref = list_array.value(i); @@ -253,25 +238,22 @@ fn array_sort_generic( }; array_lengths.push(sorted_array.len()); arrays.push(sorted_array); - valid.append_non_null(); } } - let buffer = valid.finish(); - let elements = arrays .iter() .map(|a| a.as_ref()) .collect::>(); let list_arr = if elements.is_empty() { - GenericListArray::::new_null(Arc::clone(field), row_count) + GenericListArray::::new_null(field, row_count) } else { GenericListArray::::new( - Arc::clone(field), + field, OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), - buffer, + list_array.nulls().cloned(), ) }; Ok(Arc::new(list_arr)) diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index d2a69c010e8e7..9f46917a87eb3 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, }; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::{ @@ -161,8 +161,7 @@ pub(crate) fn compare_element_to_list( ); } - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; + let element_array_row = element_array.slice(row_index, 1); // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` @@ -260,7 +259,7 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { match 
field_data_type { DataType::Struct(fields) => Ok(fields), _ => { - internal_err!("Expected a Struct type, got {:?}", field_data_type) + internal_err!("Expected a Struct type, got {}", field_data_type) } } } diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index b806798bcecc0..342269fbc2996 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -433,30 +433,11 @@ fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool { } } -fn validate_interval_step( - step: IntervalMonthDayNano, - start: i64, - end: i64, -) -> Result<()> { +fn validate_interval_step(step: IntervalMonthDayNano) -> Result<()> { if step.months == 0 && step.days == 0 && step.nanoseconds == 0 { return plan_err!("Step interval cannot be zero"); } - let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0; - let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; - - if start > end && step_is_positive { - return plan_err!( - "Start is bigger than end, but increment is positive: Cannot generate infinite series" - ); - } - - if start < end && step_is_negative { - return plan_err!( - "Start is smaller than end, but increment is negative: Cannot generate infinite series" - ); - } - Ok(()) } @@ -567,18 +548,6 @@ impl GenerateSeriesFuncImpl { } }; - if start > end && step > 0 { - return plan_err!( - "Start is bigger than end, but increment is positive: Cannot generate infinite series" - ); - } - - if start < end && step < 0 { - return plan_err!( - "Start is smaller than end, but increment is negative: Cannot generate infinite series" - ); - } - if step == 0 { return plan_err!("Step cannot be zero"); } @@ -656,7 +625,7 @@ impl GenerateSeriesFuncImpl { }; // Validate step interval - validate_interval_step(step, start, end)?; + validate_interval_step(step)?; Ok(Arc::new(GenerateSeriesTable { schema, @@ -749,7 
+718,7 @@ impl GenerateSeriesFuncImpl { let end_ts = end_date as i64 * NANOS_PER_DAY; // Validate step interval - validate_interval_step(step_interval, start_ts, end_ts)?; + validate_interval_step(step_interval)?; Ok(Arc::new(GenerateSeriesTable { schema, diff --git a/datafusion/functions-table/src/lib.rs b/datafusion/functions-table/src/lib.rs index 1783c15b14b58..cd9ade041acbf 100644 --- a/datafusion/functions-table/src/lib.rs +++ b/datafusion/functions-table/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod generate_series; diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 210e54d672893..301f2c34a6c95 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -24,8 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Common user-defined window functionality for [DataFusion] //! diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 300313387388a..6edfb92744f5b 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -25,7 +25,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] // https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Window Function packages for [DataFusion]. //! 
diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index c62f0a9ae4e89..c8980d9f1dc67 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -97,7 +97,7 @@ impl NthValue { Self { signature: Signature::one_of( vec![ - TypeSignature::Any(0), + TypeSignature::Nullary, TypeSignature::Any(1), TypeSignature::Any(2), ], diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 4ecd7a597814b..1940f1378b635 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -82,12 +82,13 @@ hex = { workspace = true, optional = true } itertools = { workspace = true } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } +memchr = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.9", optional = true } +sha2 = { workspace = true, optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.19", features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -132,6 +133,11 @@ harness = false name = "gcd" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "nanvl" +required-features = ["math_expressions"] + [[bench]] harness = false name = "uuid" @@ -186,6 +192,11 @@ harness = false name = "signum" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "atan2" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" @@ -308,10 +319,20 @@ required-features = ["string_expressions"] [[bench]] harness = false -name = "left" +name = "left_right" required-features = ["unicode_expressions"] [[bench]] harness = false name = "factorial" required-features = 
["math_expressions"] + +[[bench]] +harness = false +name = "floor_ceil" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "round" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 66d81261bfe85..a2424ed352afc 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -15,19 +15,47 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use helper::gen_string_array; use std::hint::black_box; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let ascii = datafusion_functions::string::ascii(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("ascii/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("ascii/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8View, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| 
black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); // All benches are single batch run with 8192 rows const N_ROWS: usize = 8192; diff --git a/datafusion/functions/benches/atan2.rs b/datafusion/functions/benches/atan2.rs new file mode 100644 index 0000000000000..f1c9756a0cc08 --- /dev/null +++ b/datafusion/functions/benches/atan2.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::atan2; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let atan2_fn = atan2(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let y_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)]; + let f32_arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f32 = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("atan2 f32 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f32_args.clone(), + arg_fields: f32_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let y_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)]; + let f64_arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f64 = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("atan2 f64 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { 
+ args: f64_args.clone(), + arg_fields: f64_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + } + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), + ]; + let scalar_f32_arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Float32, false).into(), + ]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("atan2 f32 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), + ]; + let scalar_f64_arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Float64, false).into(), + ]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("atan2 f64 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 35a0cf886b7f0..4927627ec2f05 100644 --- a/datafusion/functions/benches/character_length.rs +++ 
b/datafusion/functions/benches/character_length.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 9a6342ca40bb6..a702dc161ae06 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{array::PrimitiveArray, datatypes::Int64Type}; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; @@ -35,11 +34,32 @@ pub fn seedable_rng() -> StdRng { } fn criterion_benchmark(c: &mut Criterion) { - let cot_fn = chr(); + let chr_fn = chr(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("chr/scalar", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(65)))]; + let arg_fields = vec![Field::new("arg_0", DataType::Int64, true).into()]; + b.iter(|| { + black_box( + chr_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = StdRng::seed_from_u64(42); + let mut rng = seedable_rng(); (0..size) .map(|_| { if rng.random::() < null_density { @@ -57,12 +77,11 @@ fn criterion_benchmark(c: &mut Criterion) { .enumerate() .map(|(idx, arg)| Field::new(format!("arg_{idx}"), 
arg.data_type(), true).into()) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - c.bench_function("chr", |b| { + c.bench_function("chr/array", |b| { b.iter(|| { black_box( - cot_fn + chr_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), diff --git a/datafusion/functions/benches/contains.rs b/datafusion/functions/benches/contains.rs index 052eff38869dc..6c39f45e14fa6 100644 --- a/datafusion/functions/benches/contains.rs +++ b/datafusion/functions/benches/contains.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index c47198d4a6208..16c3fba2175fe 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -27,11 +25,15 @@ use datafusion_functions::math::cot; use std::hint::black_box; use arrow::datatypes::{DataType, Field}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let cot_fn = cot(); + let config_options = Arc::new(ConfigOptions::default()); + + // Array benchmarks - run for different sizes for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; @@ -42,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { Field::new(format!("arg_{idx}"), arg.data_type(), true).into() }) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { @@ -59,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; let arg_fields = f64_args @@ -86,6 +88,47 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run only once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("cot f32 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + 
let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("cot f64 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/crypto.rs b/datafusion/functions/benches/crypto.rs index bf30cc9a0c445..9a86efbff9ed8 100644 --- a/datafusion/functions/benches/crypto.rs +++ b/datafusion/functions/benches/crypto.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index eb4e960d8312b..28dee96987261 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index f5c8ceb5fe9d5..0668a1cc5085c 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; @@ -25,7 +23,7 @@ use arrow::datatypes::Field; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_functions::datetime::date_trunc; use rand::Rng; use rand::rngs::ThreadRng; @@ -57,10 +55,13 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); - let return_type = udf - .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) + let scalar_arguments = vec![None; arg_fields.len()]; + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) .unwrap(); - let return_field = Arc::new(Field::new("f", return_type, true)); let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 8a7c2b7b664b7..0b8f0c5c51a58 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::Array; use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; +use arrow::util::bench_util::create_binary_array; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -32,20 +30,22 @@ fn criterion_benchmark(c: &mut Criterion) { let config_options = Arc::new(ConfigOptions::default()); for size in [1024, 4096, 8192] { - let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + let bin_array = Arc::new(create_binary_array::(size, 0.2)); c.bench_function(&format!("base64_decode/{size}"), |b| { let method = ColumnarValue::Scalar("base64".into()); let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields: vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ], number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ @@ -61,7 +61,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), + return_field: Field::new("f", DataType::Binary, true).into(), config_options: Arc::clone(&config_options), }) .unwrap(), @@ -72,24 +72,26 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); let arg_fields = vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", 
bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields, number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ Field::new("a", encoded.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; - let return_field = Field::new("f", DataType::Utf8, true).into(); + let return_field = Field::new("f", DataType::Binary, true).into(); let args = vec![encoded, method]; b.iter(|| { diff --git a/datafusion/functions/benches/ends_with.rs b/datafusion/functions/benches/ends_with.rs index 926fd9ff72a5a..474e8a1555cf2 100644 --- a/datafusion/functions/benches/ends_with.rs +++ b/datafusion/functions/benches/ends_with.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/factorial.rs b/datafusion/functions/benches/factorial.rs index 5c5ff991d7453..c441b50c288c3 100644 --- a/datafusion/functions/benches/factorial.rs +++ b/datafusion/functions/benches/factorial.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index e207c1fa48ab3..9ee20ecd14fdf 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/floor_ceil.rs b/datafusion/functions/benches/floor_ceil.rs new file mode 100644 index 0000000000000..dc095e0152c4d --- /dev/null +++ b/datafusion/functions/benches/floor_ceil.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, Field, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::{ceil, floor}; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let floor_fn = floor(); + let ceil_fn = ceil(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("floor_ceil size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + + group.bench_function("floor_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_args = 
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))]; + + group.bench_function("floor_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 9705af8a2fcd2..3c72a46e6643d 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index ba055d58f5664..e68e41baa2e1d 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -15,19 +15,19 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{Criterion, criterion_group, criterion_main}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; fn create_args( size: usize, @@ -49,60 +49,87 @@ fn create_args( fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); - for size in [1024, 4096] { - let args = create_args::(size, 8, true); - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - c.bench_function( - format!("initcap string view shorter than 12 [size={size}]").as_str(), - |b| { - b.iter(|| { - black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - let args = create_args::(size, 16, true); - c.bench_function( - format!("initcap string view longer than 12 [size={size}]").as_str(), - |b| { - b.iter(|| { - black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { + let config_options = 
Arc::new(ConfigOptions::default()); + + // Grouped benchmarks for array sizes - to compare with scalar performance + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("initcap size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Array benchmark - Utf8 + let array_args = create_args::(size, 16, false); + let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; + let batch_len = size; + + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, + args: array_args.clone(), + arg_fields: array_arg_fields.clone(), + number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) }); + + // Array benchmark - Utf8View + let array_view_args = create_args::(size, 16, true); + let array_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; + + group.bench_function("array_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: array_view_args.clone(), + arg_fields: array_view_arg_fields.clone(), + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Scalar benchmark - Utf8 (the optimization we added) + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello world test string".to_string(), + )))]; + let scalar_arg_fields = vec![Field::new("arg_0", DataType::Utf8, false).into()]; + + group.bench_function("scalar_utf8", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, false).into(), + 
config_options: Arc::clone(&config_options), + })) + }) + }); + + // Scalar benchmark - Utf8View + let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello world test string".to_string(), + )))]; + let scalar_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, false).into()]; + + group.bench_function("scalar_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_view_args.clone(), + arg_fields: scalar_view_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8View, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); } } diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index d4e41e882fe20..e353b9d27a0a1 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 53e38745afa92..c6d0aed4c615c 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::iszero; @@ -31,6 +30,8 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let iszero = iszero(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); @@ -43,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; @@ -88,6 +89,46 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_scalar = Arc::new(Field::new("f", DataType::Boolean, false)); + + c.bench_function("iszero f32 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = 
vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + + c.bench_function("iszero f64 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/left.rs b/datafusion/functions/benches/left.rs deleted file mode 100644 index 3ea628fe2987c..0000000000000 --- a/datafusion/functions/benches/left.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -extern crate criterion; - -use std::hint::black_box; -use std::sync::Arc; - -use arrow::array::{ArrayRef, Int64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use datafusion_common::config::ConfigOptions; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use datafusion_functions::unicode::left; - -fn create_args(size: usize, str_len: usize, use_negative: bool) -> Vec { - let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - - // For negative n, we want to trigger the double-iteration code path - let n_values: Vec = if use_negative { - (0..size).map(|i| -((i % 10 + 1) as i64)).collect() - } else { - (0..size).map(|i| (i % 10 + 1) as i64).collect() - }; - let n_array = Arc::new(Int64Array::from(n_values)); - - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&n_array) as ArrayRef), - ] -} - -fn criterion_benchmark(c: &mut Criterion) { - for size in [1024, 4096] { - let mut group = c.benchmark_group(format!("left size={size}")); - - // Benchmark with positive n (no optimization needed) - let args = create_args(size, 32, false); - group.bench_function(BenchmarkId::new("positive n", size), |b| { - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - left() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("left should work"), - ) - }) - }); - - // Benchmark with negative n (triggers optimization) - let args = create_args(size, 32, true); - group.bench_function(BenchmarkId::new("negative 
n", size), |b| { - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - left() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("left should work"), - ) - }) - }); - - group.finish(); - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/datafusion/functions/benches/left_right.rs b/datafusion/functions/benches/left_right.rs new file mode 100644 index 0000000000000..59f8d8a75f74c --- /dev/null +++ b/datafusion/functions/benches/left_right.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::unicode::{left, right}; + +fn create_args( + size: usize, + str_len: usize, + use_negative: bool, + is_string_view: bool, +) -> Vec { + let string_arg = if is_string_view { + ColumnarValue::Array(Arc::new(create_string_view_array_with_len( + size, 0.1, str_len, true, + ))) + } else { + ColumnarValue::Array(Arc::new(create_string_array_with_len::( + size, 0.1, str_len, + ))) + }; + + // For negative n, we want to trigger the double-iteration code path + let n_values: Vec = if use_negative { + (0..size).map(|i| -((i % 10 + 1) as i64)).collect() + } else { + (0..size).map(|i| (i % 10 + 1) as i64).collect() + }; + let n_array = Arc::new(Int64Array::from(n_values)); + + vec![ + string_arg, + ColumnarValue::Array(Arc::clone(&n_array) as ArrayRef), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + let left_function = left(); + let right_function = right(); + + for function in [left_function, right_function] { + for is_string_view in [false, true] { + for is_negative in [false, true] { + for size in [1024, 4096] { + let function_name = function.name(); + let mut group = + c.benchmark_group(format!("{function_name} size={size}")); + + let bench_name = format!( + "{} {} n", + if is_string_view { + "string_view_array" + } else { + "string_array" + }, + if is_negative { "negative" } else { "positive" }, + ); + let return_type = if is_string_view { + DataType::Utf8View + } else { + DataType::Utf8 + }; + + let args = create_args(size, 32, is_negative, is_string_view); + group.bench_function(BenchmarkId::new(bench_name, size), |b| { + 
let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true) + .into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + b.iter(|| { + black_box( + function + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new( + "f", + return_type.clone(), + true, + ) + .into(), + config_options: Arc::clone(&config_options), + }) + .expect("should work"), + ) + }) + }); + + group.finish(); + } + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/levenshtein.rs b/datafusion/functions/benches/levenshtein.rs index 19f81b6cafcb3..08733b245ffb4 100644 --- a/datafusion/functions/benches/levenshtein.rs +++ b/datafusion/functions/benches/levenshtein.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 333dca390054b..6dbc8dcb7d148 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8b1b32edfc9c5..42b5b1019538d 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/nanvl.rs b/datafusion/functions/benches/nanvl.rs new file mode 100644 index 0000000000000..206eebd81eb81 --- /dev/null +++ b/datafusion/functions/benches/nanvl.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::nanvl; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let nanvl_fn = nanvl(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("nanvl/scalar_f64", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("nanvl/scalar_f32", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(f32::NAN))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + // Array benchmarks + for size in [1024, 4096, 8192] { + let a64: ArrayRef = Arc::new(Float64Array::from(vec![f64::NAN; size])); + let b64: ArrayRef = Arc::new(Float64Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f64/{size}"), |bench| { + let args = 
ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a64)), + ColumnarValue::Array(Arc::clone(&b64)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + let a32: ArrayRef = Arc::new(Float32Array::from(vec![f32::NAN; size])); + let b32: ArrayRef = Arc::new(Float32Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f32/{size}"), |bench| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a32)), + ColumnarValue::Array(Arc::clone(&b32)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index f937d19421e89..f9f063c52d0d4 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f6b2ed7636bf8..99f177c035597 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 3d8631140c05f..71ded120eb515 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/benches/regexp_count.rs b/datafusion/functions/benches/regexp_count.rs index eae7ef00f16bd..bce76c05585b9 100644 --- a/datafusion/functions/benches/regexp_count.rs +++ b/datafusion/functions/benches/regexp_count.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::Int64Array; use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 32378ccd126e5..c5014655a860a 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::builder::StringBuilder; use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray}; use arrow::compute::cast; diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 304739b42f5fc..354812c0d2ea2 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ @@ -24,6 +22,7 @@ use arrow::util::bench_util::{ }; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -80,6 +79,44 @@ fn invoke_repeat_with_args( } fn criterion_benchmark(c: &mut Criterion) { + let repeat_fn = string::repeat(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("repeat/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("repeat/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))), + 
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8View, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; diff --git a/datafusion/functions/benches/replace.rs b/datafusion/functions/benches/replace.rs index deadbfeb99a84..55fbd6ae57af2 100644 --- a/datafusion/functions/benches/replace.rs +++ b/datafusion/functions/benches/replace.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index 73f5be5b45df0..f2e2898bbfe43 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/functions/benches/round.rs b/datafusion/functions/benches/round.rs new file mode 100644 index 0000000000000..7010aa3507dbc --- /dev/null +++ b/datafusion/functions/benches/round.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::round; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let round_fn = round(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("round size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ + ColumnarValue::Array(f64_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Float32 array 
benchmark + let f32_array = Arc::new(create_primitive_array::(size, 0.1)); + let f32_args = vec![ + ColumnarValue::Array(f32_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_array", |b| { + b.iter(|| { + let args_cloned = f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false) + .into(), + config_options: 
Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 08a197a60eb75..e98d1b2c22ea2 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::DataType; use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::signum; @@ -88,6 +87,51 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + // Scalar benchmarks (the optimization we added) + let scalar_f32_args = + vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))]; + let scalar_f32_arg_fields = + vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function(&format!("signum f32 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = + vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))]; + let scalar_f64_arg_fields = + vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function(&format!("signum f64 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum 
+ .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } } diff --git a/datafusion/functions/benches/split_part.rs b/datafusion/functions/benches/split_part.rs index e23610338d15c..7ef84a058920e 100644 --- a/datafusion/functions/benches/split_part.rs +++ b/datafusion/functions/benches/split_part.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/starts_with.rs b/datafusion/functions/benches/starts_with.rs index 9ee39b694539c..17483f0da7a07 100644 --- a/datafusion/functions/benches/starts_with.rs +++ b/datafusion/functions/benches/starts_with.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 9babf1d05c059..94ce919c3d801 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; @@ -29,9 +27,12 @@ use std::hint::black_box; use std::str::Chars; use std::sync::Arc; -/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with -/// 4096 rows, each row containing a string with 128 random characters. -/// around 10% of the rows are null, around 10% of the rows are non-ASCII. +/// Returns a `Vec` with two elements: a haystack array and a +/// needle array. Each haystack is a random string of `str_len_chars` +/// characters. Each needle is a random contiguous substring of its +/// corresponding haystack (i.e., the needle is always present in the haystack). +/// Around `null_density` fraction of rows are null and `utf8_density` fraction +/// contain non-ASCII characters; the remaining rows are ASCII-only. fn gen_string_array( n_rows: usize, str_len_chars: usize, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index a6989c1bca456..37a1e178f5612 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 28ce6e444eb5c..663e7928bfd95 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index ac5b5dc7e03a3..65f4999d23489 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index 1c6757a291b24..33f8d9c49e8eb 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -32,6 +31,42 @@ fn criterion_benchmark(c: &mut Criterion) { let hex = string::to_hex(); let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("to_hex/scalar_i32", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(2147483647)))]; + let arg_fields = vec![Field::new("a", DataType::Int32, true).into()]; + b.iter(|| { + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + c.bench_function("to_hex/scalar_i64", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some( + 9223372036854775807, + )))]; 
+ let arg_fields = vec![Field::new("a", DataType::Int64, true).into()]; + b.iter(|| { + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + for size in [1024, 4096, 8192] { let mut group = c.benchmark_group(format!("to_hex size={size}")); group.sampling_mode(SamplingMode::Flat); diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index ed865fa6e8d50..90ea145d5d2c0 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/functions/benches/translate.rs b/datafusion/functions/benches/translate.rs index 601bdec7cd364..d0568ba0f5355 100644 --- a/datafusion/functions/benches/translate.rs +++ b/datafusion/functions/benches/translate.rs @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::array::OffsetSizeTrait; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; -use datafusion_common::DataFusionError; use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; use std::time::Duration; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_args_array_from_to( + size: usize, + str_len: usize, +) -> Vec { let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - // Create simple from/to strings for translation let from_array = Arc::new(create_string_array_with_len::(size, 0.1, 3)); let to_array = Arc::new(create_string_array_with_len::(size, 0.1, 2)); @@ -42,6 +42,19 @@ fn create_args(size: usize, str_len: usize) -> Vec( + size: usize, + str_len: usize, +) -> Vec { + let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(ScalarValue::from("aeiou")), + ColumnarValue::Scalar(ScalarValue::from("AEIOU")), + ] +} + fn invoke_translate_with_args( args: Vec, number_rows: usize, @@ -69,17 +82,22 @@ fn criterion_benchmark(c: &mut Criterion) { group.sample_size(10); group.measurement_time(Duration::from_secs(10)); - for str_len in [8, 32] { - let args = create_args::(size, str_len); - group.bench_function( - format!("translate_string [size={size}, str_len={str_len}]"), - |b| { - b.iter(|| { - let args_cloned = args.clone(); - black_box(invoke_translate_with_args(args_cloned, size)) - }) - }, - ); + for str_len in [8, 32, 128, 1024] { + let args = create_args_array_from_to::(size, str_len); + group.bench_function(format!("array_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned 
= args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); + + let args = create_args_scalar_from_to::(size, str_len); + group.bench_function(format!("scalar_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); } group.finish(); diff --git a/datafusion/functions/benches/trim.rs b/datafusion/functions/benches/trim.rs index 29bbc3f7dcb48..23a53eefb217b 100644 --- a/datafusion/functions/benches/trim.rs +++ b/datafusion/functions/benches/trim.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{ @@ -143,7 +141,7 @@ fn create_args( ] } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn run_with_string_type( group: &mut BenchmarkGroup<'_, M>, trim_func: &ScalarUDF, @@ -189,7 +187,7 @@ fn run_with_string_type( ); } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn run_trim_benchmark( c: &mut Criterion, group_name: &str, diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index d0a6e2be75e0b..ffbedcb142c71 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -32,12 +30,13 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let trunc = trunc(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let return_field = Field::new("f", DataType::Float32, true).into(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { @@ -74,6 +73,51 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - to measure optimized performance + let scalar_f64_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float64(Some(std::f64::consts::PI)), + )]; + let scalar_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let scalar_return_field = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("trunc f64 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float32(Some(std::f32::consts::PI)), + )]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let scalar_f32_return_field = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("trunc f32 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, 
+ return_field: Arc::clone(&scalar_f32_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 51ce1da0fa1f9..3f6fa36b18c13 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index df9b2bed4be2b..629fb950dd9ff 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 47a903639dde5..8d1ffb7c4c042 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -33,8 +33,8 @@ use datafusion_common::{ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ExpressionPlacement, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -499,6 +499,32 @@ impl ScalarUDFImpl for GetFieldFunc { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + // get_field can be pushed to leaves if: + // 1. The base (first arg) is a column or already placeable at leaves + // 2. 
All field keys (remaining args) are literals + if args.is_empty() { + return ExpressionPlacement::KeepInPlace; + } + + let base_placement = args[0]; + let base_is_pushable = matches!( + base_placement, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ); + + let all_keys_are_literals = args + .iter() + .skip(1) + .all(|p| matches!(p, ExpressionPlacement::Literal)); + + if base_is_pushable && all_keys_are_literals { + ExpressionPlacement::MoveTowardsLeafNodes + } else { + ExpressionPlacement::KeepInPlace + } + } } #[cfg(test)] @@ -542,4 +568,92 @@ mod tests { Ok(()) } + + #[test] + fn test_placement_literal_key() { + let func = GetFieldFunc::new(); + + // get_field(col, 'literal') -> leaf-pushable (static field access) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(get_field(col, 'a'), 'b') represented as MoveTowardsLeafNodes for base + let args = vec![ + ExpressionPlacement::MoveTowardsLeafNodes, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + } + + #[test] + fn test_placement_column_key() { + let func = GetFieldFunc::new(); + + // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Column]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, 'a', other_col) -> NOT leaf-pushable (dynamic nested lookup) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Column, + ]; + 
assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_root() { + let func = GetFieldFunc::new(); + + // get_field(root_expr, 'literal') -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::KeepInPlace, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, root_expr) -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::KeepInPlace, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_edge_cases() { + let func = GetFieldFunc::new(); + + // Empty args -> NOT leaf-pushable + assert_eq!(func.placement(&[]), ExpressionPlacement::KeepInPlace); + + // Just base, no key -> MoveTowardsLeafNodes (not a valid call but should handle gracefully) + let args = vec![ExpressionPlacement::Column]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // Literal base with literal key -> NOT leaf-pushable (would be constant-folded) + let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } } diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 56d4f23cc4e2e..8d915fb2e2c07 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -189,13 +189,14 @@ mod tests { fn test_scalar_value() -> Result<()> { let fun = UnionExtractFun::new(); - let fields = UnionFields::new( + let fields = UnionFields::try_new( vec![1, 3], vec![ Field::new("str", DataType::Utf8, false), Field::new("int", DataType::Int32, false), ], - ); + ) + .unwrap(); let args = vec![ ColumnarValue::Scalar(ScalarValue::Union( diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 
809679dea6465..fac5c82691adc 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -143,7 +143,7 @@ impl ScalarUDFImpl for UnionTagFunc { args.return_field.data_type(), )?)), }, - v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + v => exec_err!("union_tag only support unions, got {}", v.data_type()), } } diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index bda16684c8b6d..abb86b8246fc9 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,19 +17,13 @@ //! "crypto" DataFusion functions -use arrow::array::{ - Array, ArrayRef, AsArray, BinaryArray, BinaryArrayType, StringViewArray, -}; +use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, BinaryArrayType}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; -use datafusion_common::cast::as_binary_array; use arrow::compute::StringArrayType; -use datafusion_common::{ - DataFusionError, Result, ScalarValue, exec_err, internal_err, plan_err, - utils::take_function_args, -}; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, plan_err}; use datafusion_expr::ColumnarValue; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; @@ -37,53 +31,8 @@ use std::fmt; use std::str::FromStr; use std::sync::Arc; -macro_rules! 
define_digest_function { - ($NAME: ident, $METHOD: ident, $DOC: expr) => { - #[doc = $DOC] - pub fn $NAME(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?; - digest_process(data, DigestAlgorithm::$METHOD) - } - }; -} -define_digest_function!( - sha224, - Sha224, - "computes sha224 hash digest of the given input" -); -define_digest_function!( - sha256, - Sha256, - "computes sha256 hash digest of the given input" -); -define_digest_function!( - sha384, - Sha384, - "computes sha384 hash digest of the given input" -); -define_digest_function!( - sha512, - Sha512, - "computes sha512 hash digest of the given input" -); -define_digest_function!( - blake2b, - Blake2b, - "computes blake2b hash digest of the given input" -); -define_digest_function!( - blake2s, - Blake2s, - "computes blake2s hash digest of the given input" -); -define_digest_function!( - blake3, - Blake3, - "computes blake3 hash digest of the given input" -); - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum DigestAlgorithm { +pub(crate) enum DigestAlgorithm { Md5, Sha224, Sha256, @@ -135,44 +84,6 @@ impl fmt::Display for DigestAlgorithm { } } -/// computes md5 hash digest of the given input -pub fn md5(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args("md5", args)?; - let value = digest_process(data, DigestAlgorithm::Md5)?; - - // md5 requires special handling because of its unique utf8view return type - Ok(match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringViewArray = binary_array - .iter() - .map(|opt| opt.map(hex_encode::<_>)) - .collect(); - ColumnarValue::Array(Arc::new(string_array)) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode::<_>))) - } - _ => return internal_err!("Impossibly got invalid results from digest"), - }) -} - -/// Hex encoding lookup table 
for fast byte-to-hex conversion -const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; - -/// Fast hex encoding using a lookup table instead of format strings. -/// This is significantly faster than using `write!("{:02x}")` for each byte. -#[inline] -fn hex_encode>(data: T) -> String { - let bytes = data.as_ref(); - let mut s = String::with_capacity(bytes.len() * 2); - for &b in bytes { - s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); - s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); - } - s -} - macro_rules! digest_to_array { ($METHOD:ident, $INPUT:expr) => {{ let binary_array: BinaryArray = $INPUT @@ -269,7 +180,7 @@ impl DigestAlgorithm { } } -pub fn digest_process( +pub(crate) fn digest_process( value: &ColumnarValue, digest_algorithm: DigestAlgorithm, ) -> Result { diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 728e0d4a33099..355e3e287ad22 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::crypto::basic::md5; -use arrow::datatypes::DataType; +use arrow::{array::StringViewArray, datatypes::DataType}; use datafusion_common::{ - Result, + Result, ScalarValue, + cast::as_binary_array, + internal_err, types::{logical_binary, logical_string}, + utils::take_function_args, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -27,7 +29,9 @@ use datafusion_expr::{ }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; -use std::any::Any; +use std::{any::Any, sync::Arc}; + +use crate::crypto::basic::{DigestAlgorithm, digest_process}; #[user_doc( doc_section(label = "Hashing Functions"), @@ -97,3 +101,38 @@ impl ScalarUDFImpl for Md5Func { self.doc() } } + +/// Hex encoding lookup table for fast byte-to-hex conversion +const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; + +/// Fast hex encoding using a lookup table instead of format strings. +/// This is significantly faster than using `write!("{:02x}")` for each byte. 
+#[inline] +fn hex_encode(data: impl AsRef<[u8]>) -> String { + let bytes = data.as_ref(); + let mut s = String::with_capacity(bytes.len() * 2); + for &b in bytes { + s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); + s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); + } + s +} + +fn md5(args: &[ColumnarValue]) -> Result { + let [data] = take_function_args("md5", args)?; + let value = digest_process(data, DigestAlgorithm::Md5)?; + + // md5 requires special handling because of its unique utf8view return type + Ok(match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringViewArray = + binary_array.iter().map(|opt| opt.map(hex_encode)).collect(); + ColumnarValue::Array(Arc::new(string_array)) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { + ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode))) + } + _ => return internal_err!("Impossibly got invalid results from digest"), + }) +} diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 2db64beafa9b7..74c3d32a1deed 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -24,21 +24,125 @@ use arrow::array::{ }; use arrow::compute::DecimalCast; use arrow::compute::kernels::cast_utils::string_to_datetime; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{ArrowTimestampType, DataType, TimeUnit}; use arrow_buffer::ArrowNativeType; use chrono::LocalResult::Single; use chrono::format::{Parsed, StrftimeItems, parse}; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{ DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_datafusion_err, unwrap_or_internal_err, }; use datafusion_expr::ColumnarValue; +use std::ops::Add; /// Error message if 
nanosecond conversion request beyond supported interval const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; +/// This function converts a timestamp with a timezone to a timestamp without a timezone. +/// The display value of the adjusted timestamp remain the same, but the underlying timestamp +/// representation is adjusted according to the relative timezone offset to UTC. +/// +/// This function uses chrono to handle daylight saving time changes. +/// +/// For example, +/// +/// ```text +/// '2019-03-31T01:00:00Z'::timestamp at time zone 'Europe/Brussels' +/// ``` +/// +/// is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00+01:00 +/// ``` +/// +/// and is represented in DataFusion as: +/// +/// ```text +/// TimestampNanosecond(Some(1_553_990_400_000_000_000), Some("Europe/Brussels")) +/// ``` +/// +/// To strip off the timezone while keeping the display value the same, we need to +/// adjust the underlying timestamp with the timezone offset value using `adjust_to_local_time()` +/// +/// ```text +/// adjust_to_local_time(1_553_990_400_000_000_000, "Europe/Brussels") --> 1_553_994_000_000_000_000 +/// ``` +/// +/// The difference between `1_553_990_400_000_000_000` and `1_553_994_000_000_000_000` is +/// `3600_000_000_000` ns, which corresponds to 1 hour. This matches with the timezone +/// offset for "Europe/Brussels" for this date. +/// +/// Note that the offset varies with daylight savings time (DST), which makes this tricky! For +/// example, timezone "Europe/Brussels" has a 2-hour offset during DST and a 1-hour offset +/// when DST ends. 
+/// +/// Consequently, DataFusion can represent the timestamp in local time (with no offset or +/// timezone information) as +/// +/// ```text +/// TimestampNanosecond(Some(1_553_994_000_000_000_000), None) +/// ``` +/// +/// which is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00 +/// ``` +/// +/// See `test_adjust_to_local_time()` for example +pub fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { + fn convert_timestamp(ts: i64, converter: F) -> Result> + where + F: Fn(i64) -> MappedLocalTime>, + { + match converter(ts) { + MappedLocalTime::Ambiguous(earliest, latest) => exec_err!( + "Ambiguous timestamp. Do you mean {:?} or {:?}", + earliest, + latest + ), + MappedLocalTime::None => exec_err!( + "The local time does not exist because there is a gap in the local time." + ), + Single(date_time) => Ok(date_time), + } + } + + let date_time = match T::UNIT { + TimeUnit::Nanosecond => Utc.timestamp_nanos(ts), + TimeUnit::Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?, + TimeUnit::Millisecond => { + convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))? + } + TimeUnit::Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?, + }; + + let offset_seconds: i64 = tz + .offset_from_utc_datetime(&date_time.naive_utc()) + .fix() + .local_minus_utc() as i64; + + let adjusted_date_time = date_time.add( + TimeDelta::try_seconds(offset_seconds) + .ok_or_else(|| internal_datafusion_err!("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000"))?, + ); + + // convert back to i64 + match T::UNIT { + TimeUnit::Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or_else(|| { + internal_datafusion_err!( + "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. 
The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807" + ) + }), + TimeUnit::Microsecond => Ok(adjusted_date_time.timestamp_micros()), + TimeUnit::Millisecond => Ok(adjusted_date_time.timestamp_millis()), + TimeUnit::Second => Ok(adjusted_date_time.timestamp()), + } +} + static UTC: LazyLock = LazyLock::new(|| "UTC".parse().expect("UTC is always valid")); /// Converts a string representation of a date‑time into a timestamp expressed in @@ -452,8 +556,7 @@ where ) } other => exec_err!( - "Unsupported data type {other:?} for function substr,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function substr, expected Utf8View, Utf8 or LargeUtf8." ), }, other => exec_err!( @@ -487,12 +590,14 @@ where DataType::Utf8View => Ok(a.as_string_view().value(pos)), DataType::LargeUtf8 => Ok(a.as_string::().value(pos)), DataType::Utf8 => Ok(a.as_string::().value(pos)), - other => exec_err!("Unexpected type encountered '{other}'"), + other => exec_err!("Unexpected type encountered '{}'", other), }, ColumnarValue::Scalar(s) => match s.try_as_str() { Some(Some(v)) => Ok(v), Some(None) => continue, // null string - None => exec_err!("Unexpected scalar type encountered '{s}'"), + None => { + exec_err!("Unexpected scalar type encountered '{}'", s) + } }, }?; @@ -540,6 +645,6 @@ fn scalar_value(dt: &DataType, r: Option) -> Result { TimeUnit::Microsecond => Ok(ScalarValue::TimestampMicrosecond(r, tz.clone())), TimeUnit::Nanosecond => Ok(ScalarValue::TimestampNanosecond(r, tz.clone())), }, - t => Err(internal_datafusion_err!("Unsupported data type: {t:?}")), + t => Err(internal_datafusion_err!("Unsupported data type: {:?}", t)), } } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 375200d07280b..c08406de7afea 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -19,20 +19,29 @@ use 
std::any::Any; use std::str::FromStr; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, Float64Array, Int32Array}; +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, Float64Array, Int32Array, PrimitiveBuilder}; use arrow::compute::kernels::cast_utils::IntervalUnit; use arrow::compute::{DatePart, binary, date_part}; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Date32Type, Date64Type, Field, FieldRef, + IntervalUnit as ArrowIntervalUnit, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use chrono::{Datelike, NaiveDate, TimeZone, Utc}; +use datafusion_common::cast::as_primitive_array; use datafusion_common::types::{NativeType, logical_date}; +use super::adjust_to_local_time; use datafusion_common::{ Result, ScalarValue, cast::{ - as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_date32_array, as_date64_array, as_int32_array, as_interval_dt_array, + as_interval_mdn_array, as_interval_ym_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, @@ -41,9 +50,12 @@ use datafusion_common::{ types::logical_string, utils::take_function_args, }; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, 
interval_arithmetic, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; @@ -56,8 +68,9 @@ use datafusion_macros::user_doc; argument( name = "part", description = r#"Part of the date to return. The following date parts are supported: - + - year + - isoyear (ISO 8601 week-numbering year) - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - month - week (week of the year) @@ -70,7 +83,7 @@ use datafusion_macros::user_doc; - nanosecond - dow (day of the week where Sunday is 0) - doy (day of the year) - - epoch (seconds since Unix epoch) + - epoch (seconds since Unix epoch for timestamps/dates, total seconds for intervals) - isodow (day of the week where Monday is 0) "# ), @@ -122,9 +135,9 @@ impl DatePartFunc { Coercion::new_exact(TypeSignatureClass::Duration), ]), ], - Volatility::Immutable, + Volatility::Immutable ), - aliases: vec![String::from("datepart")], + aliases: vec![String::from("datepart"), String::from("extract")], } } } @@ -148,6 +161,7 @@ impl ScalarUDFImpl for DatePartFunc { fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; + let nullable = args.arg_fields[1].is_nullable(); field .and_then(|sv| { @@ -156,9 +170,9 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - Field::new(self.name(), DataType::Float64, true) + Field::new(self.name(), DataType::Float64, nullable) } else { - Field::new(self.name(), DataType::Int32, true) + Field::new(self.name(), DataType::Int32, nullable) } }) }) @@ -173,6 +187,7 @@ impl ScalarUDFImpl for DatePartFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let config = &args.config_options; let args = args.args; let [part, array] = take_function_args(self.name(), args)?; @@ -193,7 +208,10 @@ impl ScalarUDFImpl for DatePartFunc { ColumnarValue::Scalar(scalar) 
=> scalar.to_array()?, }; + let part_trim = part_normalization(&part); + let is_epoch = is_epoch(part_trim); + let array = adjust_array_for_timezone(config, array, is_epoch)?; // using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds") // and synonyms ( like "ms,msec,msecond,millisecond") to Arrow @@ -209,12 +227,12 @@ impl ScalarUDFImpl for DatePartFunc { IntervalUnit::Millisecond => seconds_as_i32(array.as_ref(), Millisecond)?, IntervalUnit::Microsecond => seconds_as_i32(array.as_ref(), Microsecond)?, IntervalUnit::Nanosecond => seconds_as_i32(array.as_ref(), Nanosecond)?, - // century and decade are not supported by `DatePart`, although they are supported in postgres _ => return exec_err!("Date part '{part}' not supported"), } } else { // special cases that can be extracted (in postgres) but are not interval units match part_trim.to_lowercase().as_str() { + "isoyear" => date_part(array.as_ref(), DatePart::YearISO)?, "qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?, "doy" => date_part(array.as_ref(), DatePart::DayOfYear)?, "dow" => date_part(array.as_ref(), DatePart::DayOfWeekSunday0)?, @@ -231,6 +249,71 @@ impl ScalarUDFImpl for DatePartFunc { }) } + // Only casting the year is supported since pruning other IntervalUnit is not possible + // date_part(col, YEAR) = 2024 => col >= '2024-01-01' and col < '2025-01-01' + // But for anything less than YEAR simplifying is not possible without specifying the bigger interval + // date_part(col, MONTH) = 1 => col = '2023-01-01' or col = '2024-01-01' or ... 
or col = '3000-01-01' + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + let [part, col_expr] = take_function_args(self.name(), args)?; + + // Get the interval unit from the part argument + let interval_unit = part + .as_literal() + .and_then(|sv| sv.try_as_str().flatten()) + .map(part_normalization) + .and_then(|s| IntervalUnit::from_str(s).ok()); + + // only support extracting year + match interval_unit { + Some(IntervalUnit::Year) => (), + _ => return Ok(PreimageResult::None), + } + + // Check if the argument is a literal (e.g. date_part(YEAR, col) = 2024) + let Some(argument_literal) = lit_expr.as_literal() else { + return Ok(PreimageResult::None); + }; + + // Extract i32 year from Scalar value + let year = match argument_literal { + ScalarValue::Int32(Some(y)) => *y, + _ => return Ok(PreimageResult::None), + }; + + // Can only extract year from Date32/64 and Timestamp column + let target_type = match info.get_data_type(col_expr)? 
{ + Date32 | Date64 | Timestamp(_, _) => &info.get_data_type(col_expr)?, + _ => return Ok(PreimageResult::None), + }; + + // Compute the Interval bounds + let Some(start_time) = NaiveDate::from_ymd_opt(year, 1, 1) else { + return Ok(PreimageResult::None); + }; + let Some(end_time) = start_time.with_year(year + 1) else { + return Ok(PreimageResult::None); + }; + + // Convert to ScalarValues + let (Some(lower), Some(upper)) = ( + date_to_scalar(start_time, target_type), + date_to_scalar(end_time, target_type), + ) else { + return Ok(PreimageResult::None); + }; + let interval = Box::new(interval_arithmetic::Interval::try_new(lower, upper)?); + + Ok(PreimageResult::Range { + expr: col_expr.clone(), + interval, + }) + } + fn aliases(&self) -> &[String] { &self.aliases } @@ -240,11 +323,160 @@ impl ScalarUDFImpl for DatePartFunc { } } +fn adjust_timestamp_array( + array: &ArrayRef, + tz: Tz, +) -> Result { + let mut builder = PrimitiveBuilder::::new(); + let primitive_array = as_primitive_array::(array)?; + for ts_opt in primitive_array.iter() { + match ts_opt { + None => builder.append_null(), + Some(ts) => { + let adjusted_ts = adjust_to_local_time::(ts, tz)?; + builder.append_value(adjusted_ts); + } + } + } + Ok(Arc::new(builder.finish())) +} + fn is_epoch(part: &str) -> bool { - let part = part_normalization(part); matches!(part.to_lowercase().as_str(), "epoch") } + +/// Adjusts the timezone of a given array containing timestamp data. +/// +/// # Arguments +/// +/// * `config` - A reference to an Arc (atomic reference-counted pointer) wrapping the configuration options. +/// * `array` - An Arc wrapping a dynamically typed array that needs to be adjusted for timezone. +/// * `is_epoch` - A boolean flag indicating whether the timestamps are represented as epoch values. +/// +/// # Behavior +/// +/// This function first checks if the provided array contains timezone-aware timestamps or not. +/// If so, it extracts these timestamps in their own timezone. 
If `is_epoch` is false and the array +/// does not contain timezone-aware timestamps, but it contains naive timestamps (timestamps without +/// associated timezone information), this function will interpret them in the session timezone if +/// available in the configuration options. +/// +/// If none of these conditions are met, the original array is returned. +fn adjust_array_for_timezone(config: &Arc, array: Arc, is_epoch: bool) -> Result> { + let (is_timezone_aware, tz_str_opt) = match array.data_type() { + Timestamp(_, Some(tz_str)) => (true, Some(Arc::clone(tz_str))), + _ => (false, None), + }; + + // Epoch is timezone-independent - it always returns seconds since 1970-01-01 UTC + let array = if is_epoch { + array + } else if is_timezone_aware { + // For timezone-aware timestamps, extract in their own timezone + match tz_str_opt.as_ref() { + Some(tz_str) => { + let tz = interpret_session_timezone(tz_str)?; + match array.data_type() { + Timestamp(time_unit, _) => match time_unit { + Nanosecond => adjust_timestamp_array::< + TimestampNanosecondType, + >(&array, tz)?, + Microsecond => adjust_timestamp_array::< + TimestampMicrosecondType, + >(&array, tz)?, + Millisecond => adjust_timestamp_array::< + TimestampMillisecondType, + >(&array, tz)?, + Second => { + adjust_timestamp_array::(&array, tz)? + } + }, + _ => array, + } + } + None => array, + } + } else if let Timestamp(time_unit, None) = array.data_type() { + // For naive timestamps, interpret in session timezone if available + match config.execution.time_zone.as_ref() { + Some(tz_str) => { + let tz = interpret_session_timezone(tz_str)?; + + match time_unit { + Nanosecond => { + adjust_timestamp_array::(&array, tz)? + } + Microsecond => { + adjust_timestamp_array::( + &array, tz, + )? + } + Millisecond => { + adjust_timestamp_array::( + &array, tz, + )? + } + Second => { + adjust_timestamp_array::(&array, tz)? 
+ } + } + } + None => array, + } + } else { + array + }; + + Ok(array) +} + +fn date_to_scalar(date: NaiveDate, target_type: &DataType) -> Option { + Some(match target_type { + Date32 => ScalarValue::Date32(Some(Date32Type::from_naive_date(date))), + Date64 => ScalarValue::Date64(Some(Date64Type::from_naive_date(date))), + + Timestamp(unit, tz_opt) => { + let naive_midnight = date.and_hms_opt(0, 0, 0)?; + + let utc_dt = if let Some(tz_str) = tz_opt { + let tz: Tz = tz_str.parse().ok()?; + + let local = tz.from_local_datetime(&naive_midnight); + + let local_dt = match local { + chrono::offset::LocalResult::Single(dt) => dt, + chrono::offset::LocalResult::Ambiguous(dt1, _dt2) => dt1, + chrono::offset::LocalResult::None => local.earliest()?, + }; + + local_dt.with_timezone(&Utc) + } else { + Utc.from_utc_datetime(&naive_midnight) + }; + + match unit { + Second => { + ScalarValue::TimestampSecond(Some(utc_dt.timestamp()), tz_opt.clone()) + } + Millisecond => ScalarValue::TimestampMillisecond( + Some(utc_dt.timestamp_millis()), + tz_opt.clone(), + ), + Microsecond => ScalarValue::TimestampMicrosecond( + Some(utc_dt.timestamp_micros()), + tz_opt.clone(), + ), + Nanosecond => ScalarValue::TimestampNanosecond( + Some(utc_dt.timestamp_nanos_opt()?), + tz_opt.clone(), + ), + } + } + _ => return None, + }) +} + // Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error fn part_normalization(part: &str) -> &str { part.strip_prefix(|c| c == '\'' || c == '\"') @@ -252,9 +484,13 @@ fn part_normalization(part: &str) -> &str { .unwrap_or(part) } -/// Invoke [`date_part`] on an `array` (e.g. 
Timestamp) and convert the -/// result to a total number of seconds, milliseconds, microseconds or -/// nanoseconds +fn interpret_session_timezone(tz_str: &str) -> Result { + match tz_str.parse::() { + Ok(tz) => Ok(tz), + Err(err) => exec_err!("Invalid timezone '{tz_str}': {err}"), + } +} + fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result { // Nanosecond is neither supported in Postgres nor DuckDB, to avoid dealing // with overflow and precision issue we don't support nanosecond @@ -277,7 +513,6 @@ fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result { }; let secs = date_part(array, DatePart::Second)?; - // This assumes array is primitive and not a dictionary let secs = as_int32_array(secs.as_ref())?; let subsecs = date_part(array, DatePart::Nanosecond)?; let subsecs = as_int32_array(subsecs.as_ref())?; @@ -305,11 +540,8 @@ fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result { } } -/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the -/// result to a total number of seconds, milliseconds, microseconds or -/// nanoseconds -/// -/// Given epoch return f64, this is a duplicated function to optimize for f64 type +// Converts seconds to f64 with the specified time unit. +// Used for Interval and Duration types that need floating-point precision. 
fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { let sf = match unit { Second => 1_f64, @@ -318,7 +550,6 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { Nanosecond => 1_000_000_000_f64, }; let secs = date_part(array, DatePart::Second)?; - // This assumes array is primitive and not a dictionary let secs = as_int32_array(secs.as_ref())?; let subsecs = date_part(array, DatePart::Nanosecond)?; let subsecs = as_int32_array(subsecs.as_ref())?; @@ -349,6 +580,11 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { fn epoch(array: &dyn Array) -> Result { const SECONDS_IN_A_DAY: f64 = 86400_f64; + // Note: Month-to-second conversion uses 30 days as an approximation. + // This matches PostgreSQL's behavior for interval epoch extraction, + // but does not represent exact calendar months (which vary 28-31 days). + // See: https://doxygen.postgresql.org/datatype_2timestamp_8h.html + const DAYS_PER_MONTH: f64 = 30_f64; let f: Float64Array = match array.data_type() { Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), @@ -373,7 +609,19 @@ fn epoch(array: &dyn Array) -> Result { Time64(Nanosecond) => { as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Interval(_) | Duration(_) => return seconds(array, Second), + Interval(ArrowIntervalUnit::YearMonth) => as_interval_ym_array(array)? 
+ .unary(|x| x as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY), + Interval(ArrowIntervalUnit::DayTime) => as_interval_dt_array(array)?.unary(|x| { + x.days as f64 * SECONDS_IN_A_DAY + x.milliseconds as f64 / 1_000_f64 + }), + Interval(ArrowIntervalUnit::MonthDayNano) => { + as_interval_mdn_array(array)?.unary(|x| { + x.months as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY + + x.days as f64 * SECONDS_IN_A_DAY + + x.nanoseconds as f64 / 1_000_000_000_f64 + }) + } + Duration(_) => return seconds(array, Second), d => return exec_err!("Cannot convert {d:?} to epoch"), }; Ok(Arc::new(f)) diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 8c8a4a1c1b771..8497e583ba4bc 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -34,14 +34,16 @@ use arrow::array::types::{ use arrow::array::{Array, ArrayRef, PrimitiveArray}; use arrow::datatypes::DataType::{self, Time32, Time64, Timestamp}; use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; +use arrow::datatypes::{Field, FieldRef}; use datafusion_common::cast::as_primitive_array; use datafusion_common::types::{NativeType, logical_date, logical_string}; use datafusion_common::{ - DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, }; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; @@ -221,12 +223,22 @@ impl ScalarUDFImpl for DateTruncFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if 
arg_types[1].is_null() { - Ok(Timestamp(Nanosecond, None)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let field = &args.arg_fields[1]; + let return_type = if field.data_type().is_null() { + Timestamp(Nanosecond, None) } else { - Ok(arg_types[1].clone()) - } + field.data_type().clone() + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + field.is_nullable(), + ))) } fn invoke_with_args( diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index 39b9453295df6..4f3e45d761c34 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; pub mod common; +pub use common::adjust_to_local_time; pub mod current_date; pub mod current_time; pub mod date_bin; diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8d0c47cfe664c..2c6f8235457c3 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -153,7 +153,7 @@ impl ScalarUDFImpl for ToCharFunc { ColumnarValue::Array(_) => to_char_array(&args), _ => { exec_err!( - "Format for `to_char` must be non-null Utf8, received {:?}", + "Format for `to_char` must be non-null Utf8, received {}", format.data_type() ) } @@ -814,7 +814,7 @@ mod tests { let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( result.err().unwrap().strip_backtrace(), - "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(Nanosecond, None)" + "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(ns)" ); } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 86c949711d011..178714c78bd37 100644 
--- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -16,7 +16,6 @@ // under the License. use std::any::Any; -use std::ops::Add; use std::sync::Arc; use arrow::array::timezone::Tz; @@ -27,13 +26,10 @@ use arrow::datatypes::{ ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; +use crate::datetime::adjust_to_local_time; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{ - Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, - utils::take_function_args, -}; +use datafusion_common::{Result, ScalarValue, internal_err, utils::take_function_args}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -324,60 +320,12 @@ fn to_local_time(time_value: &ColumnarValue) -> Result { /// ``` /// /// See `test_adjust_to_local_time()` for example -fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { - fn convert_timestamp(ts: i64, converter: F) -> Result> - where - F: Fn(i64) -> MappedLocalTime>, - { - match converter(ts) { - MappedLocalTime::Ambiguous(earliest, latest) => exec_err!( - "Ambiguous timestamp. Do you mean {:?} or {:?}", - earliest, - latest - ), - MappedLocalTime::None => exec_err!( - "The local time does not exist because there is a gap in the local time." 
- ), - MappedLocalTime::Single(date_time) => Ok(date_time), - } - } - - let date_time = match T::UNIT { - Nanosecond => Utc.timestamp_nanos(ts), - Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?, - Millisecond => convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))?, - Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?, - }; - - let offset_seconds: i64 = tz - .offset_from_utc_datetime(&date_time.naive_utc()) - .fix() - .local_minus_utc() as i64; - - let adjusted_date_time = date_time.add( - // This should not fail under normal circumstances as the - // maximum possible offset is 26 hours (93,600 seconds) - TimeDelta::try_seconds(offset_seconds) - .ok_or_else(|| internal_datafusion_err!("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000"))?, - ); - - // convert the naive datetime back to i64 - match T::UNIT { - Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or_else(|| - internal_datafusion_err!( - "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. 
The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807" - ) - ), - Microsecond => Ok(adjusted_date_time.timestamp_micros()), - Millisecond => Ok(adjusted_date_time.timestamp_millis()), - Second => Ok(adjusted_date_time.timestamp()), - } -} - #[cfg(test)] mod tests { use std::sync::Arc; + use super::ToLocalTimeFunc; + use crate::datetime::adjust_to_local_time; use arrow::array::{Array, TimestampNanosecondArray, types::TimestampNanosecondType}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, Field, TimeUnit}; @@ -386,8 +334,6 @@ mod tests { use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; - use super::{ToLocalTimeFunc, adjust_to_local_time}; - #[test] fn test_adjust_to_local_time() { let timestamp_str = "2020-03-31T13:40:00"; diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 1c5d3dbd88bcd..6d40133bd29bf 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -42,7 +42,8 @@ use datafusion_macros::user_doc; description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). 
@@ -96,7 +97,8 @@ pub struct ToTimestampFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -145,7 +147,8 @@ pub struct ToTimestampSecondsFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -194,7 +197,8 @@ pub struct ToTimestampMillisFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. 
Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -243,7 +247,8 @@ pub struct ToTimestampMicrosFunc { description = r#" Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone. Integers, unsigned integers, and doubles are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. @@ -425,27 +430,56 @@ impl ScalarUDFImpl for ToTimestampFunc { .cast_to(&Timestamp(Second, None), None)? 
.cast_to(&Timestamp(Nanosecond, tz), None), Null | Timestamp(_, _) => args[0].cast_to(&Timestamp(Nanosecond, tz), None), - Float16 => { - let arr = args[0].to_array(1)?; - let f16_arr = downcast_arg!(&arr, Float16Array); - let result: TimestampNanosecondArray = - f16_arr.unary(|x| (x.to_f64() * 1_000_000_000.0) as i64); - Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) - } - Float32 => { - let arr = args[0].to_array(1)?; - let f32_arr = downcast_arg!(&arr, Float32Array); - let result: TimestampNanosecondArray = - f32_arr.unary(|x| (x as f64 * 1_000_000_000.0) as i64); - Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) - } - Float64 => { - let arr = args[0].to_array(1)?; - let f64_arr = downcast_arg!(&arr, Float64Array); - let result: TimestampNanosecondArray = - f64_arr.unary(|x| (x * 1_000_000_000.0) as i64); - Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) - } + Float16 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float16(value)) => { + let timestamp_nanos = + value.map(|v| (v.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f16_arr = downcast_arg!(arr, Float16Array); + let result: TimestampNanosecondArray = + f16_arr.unary(|x| (x.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float16 value for to_timestamp"), + }, + Float32 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float32(value)) => { + let timestamp_nanos = + value.map(|v| (v as f64 * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f32_arr = downcast_arg!(arr, Float32Array); + let result: TimestampNanosecondArray = + f32_arr.unary(|x| (x as f64 * 1_000_000_000.0) as i64); + 
Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float32 value for to_timestamp"), + }, + Float64 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float64(value)) => { + let timestamp_nanos = value.map(|v| (v * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f64_arr = downcast_arg!(arr, Float64Array); + let result: TimestampNanosecondArray = + f64_arr.unary(|x| (x * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float64 value for to_timestamp"), + }, Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _) => { let arg = args[0].cast_to(&Decimal128(38, 9), None)?; decimal128_to_timestamp_nanos(&arg, tz) diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 5ebcce0a7cfc2..2dd377282725a 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -27,7 +27,12 @@ use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`).", + description = r#" +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). +Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. +Strings are parsed as RFC3339 (e.g. 
'2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`)."#, syntax_example = "to_unixtime(expression[, ..., format_n])", sql_example = r#" ```sql diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 7b72c264e5557..4ad67b78178f2 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -19,8 +19,8 @@ use arrow::{ array::{ - Array, ArrayRef, AsArray, BinaryArrayType, FixedSizeBinaryArray, - GenericBinaryArray, GenericStringArray, OffsetSizeTrait, + Array, ArrayRef, AsArray, BinaryArrayType, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }, datatypes::DataType, }; @@ -52,6 +52,12 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( .with_decode_padding_mode(DecodePaddingMode::Indifferent), ); +// Generate padding characters when encoding +const BASE64_ENGINE_PADDED: GeneralPurpose = GeneralPurpose::new( + &base64::alphabet::STANDARD, + GeneralPurposeConfig::new().with_encode_padding(true), +); + #[user_doc( doc_section(label = "Binary String Functions"), description = "Encode binary data into a textual representation.", @@ -62,7 +68,7 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( ), argument( name = "format", - description = "Supported formats are: `base64`, `hex`" + description = "Supported formats are: `base64`, `base64pad`, `hex`" ), related_udf(name = "decode") )] @@ -239,7 +245,7 @@ fn encode_array(array: &ArrayRef, encoding: Encoding) -> Result { encoding.encode_array::<_, i64>(&array.as_binary::()) } DataType::FixedSizeBinary(_) => { - encoding.encode_fsb_array(array.as_fixed_size_binary()) + encoding.encode_array::<_, i32>(&array.as_fixed_size_binary()) } dt => { internal_err!("Unexpected data type for encode: {dt}") @@ -307,7 +313,7 @@ fn 
decode_array(array: &ArrayRef, encoding: Encoding) -> Result { let array = array.as_fixed_size_binary(); // TODO: could we be more conservative by accounting for nulls? let estimate = array.len().saturating_mul(*size as usize); - encoding.decode_fsb_array(array, estimate) + encoding.decode_array::<_, i32>(&array, estimate) } dt => { internal_err!("Unexpected data type for decode: {dt}") @@ -319,12 +325,18 @@ fn decode_array(array: &ArrayRef, encoding: Encoding) -> Result { #[derive(Debug, Copy, Clone)] enum Encoding { Base64, + Base64Padded, Hex, } impl fmt::Display for Encoding { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", format!("{self:?}").to_lowercase()) + let name = match self { + Self::Base64 => "base64", + Self::Base64Padded => "base64pad", + Self::Hex => "hex", + }; + write!(f, "{name}") } } @@ -345,9 +357,10 @@ impl TryFrom<&ColumnarValue> for Encoding { }; match encoding { "base64" => Ok(Self::Base64), + "base64pad" => Ok(Self::Base64Padded), "hex" => Ok(Self::Hex), _ => { - let options = [Self::Base64, Self::Hex] + let options = [Self::Base64, Self::Base64Padded, Self::Hex] .iter() .map(|i| i.to_string()) .collect::>() @@ -364,15 +377,18 @@ impl Encoding { fn encode_bytes(self, value: &[u8]) -> String { match self { Self::Base64 => BASE64_ENGINE.encode(value), + Self::Base64Padded => BASE64_ENGINE_PADDED.encode(value), Self::Hex => hex::encode(value), } } fn decode_bytes(self, value: &[u8]) -> Result> { match self { - Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { - exec_datafusion_err!("Failed to decode value using base64: {e}") - }), + Self::Base64 | Self::Base64Padded => { + BASE64_ENGINE.decode(value).map_err(|e| { + exec_datafusion_err!("Failed to decode value using {self}: {e}") + }) + } Self::Hex => hex::decode(value).map_err(|e| { exec_datafusion_err!("Failed to decode value using hex: {e}") }), @@ -396,26 +412,15 @@ impl Encoding { .collect(); Ok(Arc::new(array)) } - Self::Hex => { - let array: 
GenericStringArray = - array.iter().map(|x| x.map(hex::encode)).collect(); - Ok(Arc::new(array)) - } - } - } - - // TODO: refactor this away once https://github.com/apache/arrow-rs/pull/8993 lands - fn encode_fsb_array(self, array: &FixedSizeBinaryArray) -> Result { - match self { - Self::Base64 => { - let array: GenericStringArray = array + Self::Base64Padded => { + let array: GenericStringArray = array .iter() - .map(|x| x.map(|x| BASE64_ENGINE.encode(x))) + .map(|x| x.map(|x| BASE64_ENGINE_PADDED.encode(x))) .collect(); Ok(Arc::new(array)) } Self::Hex => { - let array: GenericStringArray = + let array: GenericStringArray = array.iter().map(|x| x.map(hex::encode)).collect(); Ok(Arc::new(array)) } @@ -448,7 +453,7 @@ impl Encoding { } match self { - Self::Base64 => { + Self::Base64 | Self::Base64Padded => { let upper_bound = base64::decoded_len_estimate(approx_data_size); delegated_decode::<_, _, OutputOffset>(base64_decode, value, upper_bound) } @@ -461,73 +466,6 @@ impl Encoding { } } } - - // TODO: refactor this away once https://github.com/apache/arrow-rs/pull/8993 lands - fn decode_fsb_array( - self, - value: &FixedSizeBinaryArray, - approx_data_size: usize, - ) -> Result { - fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { - // only write input / 2 bytes to buf - let out_len = input.len() / 2; - let buf = &mut buf[..out_len]; - hex::decode_to_slice(input, buf) - .map_err(|e| exec_datafusion_err!("Failed to decode from hex: {e}"))?; - Ok(out_len) - } - - fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { - BASE64_ENGINE - .decode_slice(input, buf) - .map_err(|e| exec_datafusion_err!("Failed to decode from base64: {e}")) - } - - fn delegated_decode( - decode: DecodeFunction, - input: &FixedSizeBinaryArray, - conservative_upper_bound_size: usize, - ) -> Result - where - DecodeFunction: Fn(&[u8], &mut [u8]) -> Result, - { - let mut values = vec![0; conservative_upper_bound_size]; - let mut offsets = OffsetBufferBuilder::new(input.len()); - let mut 
total_bytes_decoded = 0; - for v in input.iter() { - if let Some(v) = v { - let cursor = &mut values[total_bytes_decoded..]; - let decoded = decode(v, cursor)?; - total_bytes_decoded += decoded; - offsets.push_length(decoded); - } else { - offsets.push_length(0); - } - } - // We reserved an upper bound size for the values buffer, but we only use the actual size - values.truncate(total_bytes_decoded); - let binary_array = GenericBinaryArray::::try_new( - offsets.finish(), - Buffer::from_vec(values), - input.nulls().cloned(), - )?; - Ok(Arc::new(binary_array)) - } - - match self { - Self::Base64 => { - let upper_bound = base64::decoded_len_estimate(approx_data_size); - delegated_decode(base64_decode, value, upper_bound) - } - Self::Hex => { - // Calculate the upper bound for decoded byte size - // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded - // So the upper bound is half the length of the input values. - let upper_bound = approx_data_size / 2; - delegated_decode(hex_decode, value, upper_bound) - } - } - } } fn delegated_decode<'a, DecodeFunction, InputBinaryArray, OutputOffset>( diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index f88304a6a5f8d..b9ce113efa627 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Function packages for [DataFusion]. //! diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 4adc331fef669..380877b593643 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -332,7 +332,8 @@ macro_rules! 
make_math_binary_udf { use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{Result, exec_err}; + use datafusion_common::utils::take_function_args; + use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::TypeSignature; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -393,37 +394,76 @@ macro_rules! make_math_binary_udf { &self, args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - y, - x, - |y, x| f64::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - DataType::Float32 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - y, - x, - |y, x| f32::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + let ScalarFunctionArgs { + args, return_field, .. 
+ } = args; + let return_type = return_field.data_type(); + let [y, x] = take_function_args(self.name(), args)?; + + match (y, x) { + ( + ColumnarValue::Scalar(y_scalar), + ColumnarValue::Scalar(x_scalar), + ) => match (&y_scalar, &x_scalar) { + (y, x) if y.is_null() || x.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_type, None) + } + ( + ScalarValue::Float64(Some(yv)), + ScalarValue::Float64(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::$BINARY_FUNC(*yv, *xv), + )))), + ( + ScalarValue::Float32(Some(yv)), + ScalarValue::Float32(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( + f32::$BINARY_FUNC(*yv, *xv), + )))), + _ => internal_err!( + "Unexpected scalar types for function {}: {:?}, {:?}", + self.name(), + y_scalar.data_type(), + x_scalar.data_type() + ), + }, + (y, x) => { + let args = ColumnarValue::values_to_arrays(&[y, x])?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) } - }; - - Ok(ColumnarValue::Array(arr)) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 081668f7669f6..1b5aaf7745a84 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -50,6 +50,7 @@ macro_rules! 
make_abs_function { }}; } +#[macro_export] macro_rules! make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { @@ -62,7 +63,8 @@ macro_rules! make_try_abs_function { x )) }) - })?; + }) + .and_then(|v| Ok(v.with_data_type(input.data_type().clone())))?; // maintain decimal's precision and scale Ok(Arc::new(res) as ArrayRef) } }}; diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 501741002f968..5961b3cb27fed 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -95,8 +95,35 @@ impl ScalarUDFImpl for CeilFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let value = &args[0]; + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::ceil), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::ceil), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; let result: ArrayRef = match value.data_type() { DataType::Float64 => Arc::new( @@ -114,7 +141,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal32(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -123,7 +150,7 @@ impl ScalarUDFImpl for CeilFunc { } 
DataType::Decimal64(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -132,7 +159,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal128(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -141,7 +168,7 @@ impl ScalarUDFImpl for CeilFunc { } DataType::Decimal256(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -156,7 +183,12 @@ impl ScalarUDFImpl for CeilFunc { } }; - Ok(ColumnarValue::Array(result)) + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index a0d7b02b68e5a..1f67ef713833f 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -18,12 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -96,24 +96,47 @@ impl ScalarUDFImpl for CotFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(cot, vec![])(&args.args) - } -} + let return_field = args.return_field; + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return 
ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_field.data_type(), None); + } -///cot SQL function -fn cot(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>(|x: f64| compute_cot64(x)), - ) as ArrayRef), - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>(|x: f32| compute_cot32(x)), - ) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function cot"), + match scalar { + ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(Some(compute_cot64(v))), + )), + ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float32(Some(compute_cot32(v))), + )), + _ => { + internal_err!( + "Unexpected scalar type for cot: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float64Type>(compute_cot64), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float32Type>(compute_cot32), + ))), + other => { + internal_err!("Unexpected data type {other:?} for function cot") + } + }, + } } } @@ -129,54 +152,212 @@ fn compute_cot64(x: f64) -> f64 { #[cfg(test)] mod test { - use crate::math::cot::cot; + use std::sync::Arc; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::ScalarValue; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use crate::math::cot::CotFunc; #[test] fn test_cot_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - 
as_float32_array(&result).expect("failed to initialize function cot"); - - let expected = Float32Array::from(vec![ - -1.986_460_4, - -0.156_119_96, - -0.501_202_8, - 0.156_119_96, - ]); - - let eps = 1e-6; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] fn test_cot_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float64_array(&result).expect("failed to initialize function cot"); - - let expected = Float64Array::from(vec![ - -1.986_458_685_881_4, - 
-0.156_119_952_161_6, - -0.501_202_783_380_1, - 0.156_119_952_161_6, - ]); - - let eps = 1e-12; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_cot_scalar_f64() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + 
.invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(1.0) = 1/tan(1.0) ≈ 0.6420926159343306 + let expected = 1.0_f64 / 1.0_f64.tan(); + assert!((v - expected).abs() < 1e-12); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_f32() { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => { + let expected = 1.0_f32 / 1.0_f32.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float32 scalar"), + } + } + + #[test] + fn test_cot_scalar_null() { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(None))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot null should succeed"); + + match result { + ColumnarValue::Scalar(scalar) => { + assert!(scalar.is_null()); + } + _ => panic!("Expected scalar result"), + } + } + + #[test] + fn test_cot_scalar_zero() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + 
let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot zero should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(0) = 1/tan(0) = infinity + assert!(v.is_infinite()); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_pi() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot pi should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(PI) = 1/tan(PI) - very large negative number due to floating point + let expected = 1.0_f64 / std::f64::consts::PI.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float64 scalar"), + } } } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index ffe12466dc173..c1dd802140c04 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -22,8 +22,9 @@ use std::sync::Arc; use arrow::datatypes::DataType::Int64; use arrow::datatypes::{DataType, Int64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -81,7 +82,39 @@ impl ScalarUDFImpl for FactorialFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(factorial, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { 
+ ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))); + } + + match scalar { + ScalarValue::Int64(Some(v)) => { + let result = compute_factorial(v)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function factorial", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Int64 => { + let result: Int64Array = array + .as_primitive::() + .try_unary(compute_factorial)?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } + other => { + internal_err!("Unexpected data type {other:?} for function factorial") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -113,53 +146,12 @@ const FACTORIALS: [i64; 21] = [ 2432902008176640000, ]; // if return type changes, this constant needs to be updated accordingly -/// Factorial SQL function -fn factorial(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Int64 => { - let result: Int64Array = - args[0].as_primitive::().try_unary(|a| { - if a < 0 { - Ok(1) - } else if a < FACTORIALS.len() as i64 { - Ok(FACTORIALS[a as usize]) - } else { - exec_err!("Overflow happened on FACTORIAL({a})") - } - })?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!("Unsupported data type {other:?} for function factorial."), - } -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion_common::cast::as_int64_array; - - #[test] - fn test_factorial_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4, 20, -1])), // input - ]; - - let result = factorial(&args).expect("failed to initialize function factorial"); - let ints = - as_int64_array(&result).expect("failed to initialize function factorial"); - - let expected = Int64Array::from(vec![1, 1, 2, 24, 2432902008176640000, 1]); - - assert_eq!(ints, &expected); - } - - #[test] - fn test_overflow() { - let args: Vec = vec![ - 
Arc::new(Int64Array::from(vec![21])), // input - ]; - - let result = factorial(&args); - assert!(result.is_err()); +fn compute_factorial(n: i64) -> Result { + if n < 0 { + Ok(1) + } else if n < FACTORIALS.len() as i64 { + Ok(FACTORIALS[n as usize]) + } else { + exec_err!("Overflow happened on FACTORIAL({n})") } } diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index 221e58e1e7a7f..d4f25716ff7ee 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -19,18 +19,22 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::{DecimalCast, rescale_decimal}; use arrow::datatypes::{ - DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, - Float64Type, + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, }; use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use num_traits::{CheckedAdd, Float, One}; use super::decimal::{apply_decimal_op, floor_decimal_value}; @@ -74,6 +78,42 @@ impl FloorFunc { } } +// ============ Macro for preimage bounds ============ +/// Generates the code to call the appropriate bounds function and wrap results. +macro_rules! 
preimage_bounds { + // Float types: call float_preimage_bounds and wrap in ScalarValue + (float: $variant:ident, $value:expr) => { + float_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Integer types: call int_preimage_bounds and wrap in ScalarValue + (int: $variant:ident, $value:expr) => { + int_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Decimal types: call decimal_preimage_bounds with precision/scale and wrap in ScalarValue + (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => { + decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map( + |(lo, hi)| { + ( + ScalarValue::$variant(Some(lo), $precision, $scale), + ScalarValue::$variant(Some(hi), $precision, $scale), + ) + }, + ) + }; +} + impl ScalarUDFImpl for FloorFunc { fn as_any(&self) -> &dyn Any { self @@ -95,8 +135,35 @@ impl ScalarUDFImpl for FloorFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let value = &args[0]; + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::floor), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::floor), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, 
ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; let result: ArrayRef = match value.data_type() { DataType::Float64 => Arc::new( @@ -114,7 +181,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal32(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -123,7 +190,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal64(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -132,7 +199,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal128(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -141,7 +208,7 @@ impl ScalarUDFImpl for FloorFunc { } DataType::Decimal256(precision, scale) => { apply_decimal_op::( - value, + &value, *precision, *scale, self.name(), @@ -156,7 +223,12 @@ impl ScalarUDFImpl for FloorFunc { } }; - Ok(ColumnarValue::Array(result)) + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -168,7 +240,450 @@ impl ScalarUDFImpl for FloorFunc { Interval::make_unbounded(&data_type) } + /// Compute the preimage for floor function. + /// + /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1` + /// because floor(x) = N for all x in [N, N+1). + /// + /// This enables predicate pushdown optimizations, transforming: + /// `floor(col) = 100` into `col >= 100 AND col < 101` + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + // floor takes exactly one argument and we do not expect to reach here with multiple arguments. 
+ debug_assert!(args.len() == 1, "floor() takes exactly one argument"); + + let arg = args[0].clone(); + + // Extract the literal value being compared to + let Expr::Literal(lit_value, _) = lit_expr else { + return Ok(PreimageResult::None); + }; + + // Compute lower bound (N) and upper bound (N + 1) using helper functions + let Some((lower, upper)) = (match lit_value { + // Floating-point types + ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n), + ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n), + + // Integer types (not reachable from SQL/SLT: floor() only accepts Float64/Float32/Decimal, + // so the RHS literal is always coerced to one of those before preimage runs; kept for + // programmatic use and unit tests) + ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n), + ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n), + ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n), + ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n), + + // Decimal types + // DECIMAL(precision, scale) where precision ≤ 38 -> Decimal128(precision, scale) + // DECIMAL(precision, scale) where precision > 38 -> Decimal256(precision, scale) + // Decimal32 and Decimal64 are unreachable from SQL/SLT. 
+ ScalarValue::Decimal32(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale) + } + ScalarValue::Decimal64(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale) + } + ScalarValue::Decimal128(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale) + } + ScalarValue::Decimal256(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale) + } + + // Unsupported types + _ => None, + }) else { + return Ok(PreimageResult::None); + }; + + Ok(PreimageResult::Range { + expr: arg, + interval: Box::new(Interval::try_new(lower, upper)?), + }) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } + +// ============ Helper functions for preimage bounds ============ + +/// Compute preimage bounds for floor function on floating-point types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value is non-finite (infinity, NaN) +/// - The value is not an integer (floor always returns integers, so floor(x) = 1.3 has no solution) +/// - Adding 1 would lose precision at extreme values +fn float_preimage_bounds(n: F) -> Option<(F, F)> { + let one = F::one(); + // Check for non-finite values (infinity, NaN) + if !n.is_finite() { + return None; + } + // floor always returns an integer, so if n has a fractional part, there's no solution + if n.fract() != F::zero() { + return None; + } + // Check for precision loss at extreme values + if n + one <= n { + return None; + } + Some((n, n + one)) +} + +/// Compute preimage bounds for floor function on integer types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if adding 1 would overflow. 
+fn int_preimage_bounds(n: I) -> Option<(I, I)> { + let upper = n.checked_add(&I::one())?; + Some((n, upper)) +} + +/// Compute preimage bounds for floor function on decimal types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value has a fractional part (floor always returns integers) +/// - Adding 1 would overflow +fn decimal_preimage_bounds( + value: D::Native, + precision: u8, + scale: i8, +) -> Option<(D::Native, D::Native)> +where + D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem, +{ + // Use rescale_decimal to compute "1" at target scale (avoids manual pow) + // Convert integer 1 (scale=0) to the target scale + let one_scaled: D::Native = rescale_decimal::( + D::Native::ONE, // value = 1 + 1, // input_precision = 1 + 0, // input_scale = 0 (integer) + precision, // output_precision + scale, // output_scale + )?; + + // floor always returns an integer, so if value has a fractional part, there's no solution + // Check: value % one_scaled != 0 means fractional part exists + if scale > 0 && value % one_scaled != D::Native::ZERO { + return None; + } + + // Compute upper bound using checked addition + // Before preimage stage, the internal i128/i256(value) is validated based on the precision and scale. + // MAX_DECIMAL128_FOR_EACH_PRECISION and MAX_DECIMAL256_FOR_EACH_PRECISION are used to validate the internal i128/i256. + // Any invalid i128/i256 will not reach here. + // Therefore, the add_checked will always succeed if tested via SQL/SLT path. 
+ let upper = value.add_checked(one_scaled).ok()?; + + Some((value, upper)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::i256; + use datafusion_expr::col; + + /// Helper to test valid preimage cases that should return a Range + fn assert_preimage_range( + input: ScalarValue, + expected_lower: ScalarValue, + expected_upper: ScalarValue, + ) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + + match result { + PreimageResult::Range { expr, interval } => { + assert_eq!(expr, col("x")); + assert_eq!(interval.lower().clone(), expected_lower); + assert_eq!(interval.upper().clone(), expected_upper); + } + PreimageResult::None => { + panic!("Expected Range, got None for input {input:?}") + } + } + } + + /// Helper to test cases that should return None + fn assert_preimage_none(input: ScalarValue) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + assert!( + matches!(result, PreimageResult::None), + "Expected None for input {input:?}" + ); + } + + #[test] + fn test_floor_preimage_valid_cases() { + // Float64 + assert_preimage_range( + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(101.0)), + ); + // Float32 + assert_preimage_range( + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(51.0)), + ); + // Int64 + assert_preimage_range( + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(43)), + ); + // Int32 + assert_preimage_range( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(101)), + ); + // Negative values + 
assert_preimage_range( + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-4.0)), + ); + // Zero + assert_preimage_range( + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(1.0)), + ); + } + + #[test] + fn test_floor_preimage_non_integer_float() { + // floor(x) = 1.3 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer literals + assert_preimage_none(ScalarValue::Float64(Some(1.3))); + assert_preimage_none(ScalarValue::Float64(Some(-2.5))); + assert_preimage_none(ScalarValue::Float32(Some(3.7))); + } + + #[test] + fn test_floor_preimage_integer_overflow() { + // All integer types at MAX value should return None + assert_preimage_none(ScalarValue::Int64(Some(i64::MAX))); + assert_preimage_none(ScalarValue::Int32(Some(i32::MAX))); + assert_preimage_none(ScalarValue::Int16(Some(i16::MAX))); + assert_preimage_none(ScalarValue::Int8(Some(i8::MAX))); + } + + #[test] + fn test_floor_preimage_float_edge_cases() { + // Float64 edge cases + assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NAN))); + assert_preimage_none(ScalarValue::Float64(Some(f64::MAX))); // precision loss + + // Float32 edge cases + assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NAN))); + assert_preimage_none(ScalarValue::Float32(Some(f32::MAX))); // precision loss + } + + #[test] + fn test_floor_preimage_null_values() { + assert_preimage_none(ScalarValue::Float64(None)); + assert_preimage_none(ScalarValue::Float32(None)); + assert_preimage_none(ScalarValue::Int64(None)); + } + + // ============ Decimal32 Tests (mirrors float/int tests) ============ + + #[test] + fn 
test_floor_preimage_decimal_valid_cases() { + // ===== Decimal32 ===== + // Positive integer decimal: 100.00 (scale=2, so raw=10000) + // floor(x) = 100.00 -> x in [100.00, 101.00) + assert_preimage_range( + ScalarValue::Decimal32(Some(10000), 9, 2), + ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00 + ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00 + ); + + // Smaller positive: 50.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(5000), 9, 2), + ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00 + ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00 + ); + + // Negative integer decimal: -5.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(-500), 9, 2), + ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00 + ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00 + ); + + // Zero: 0.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(0), 9, 2), + ScalarValue::Decimal32(Some(0), 9, 2), // 0.00 + ScalarValue::Decimal32(Some(100), 9, 2), // 1.00 + ); + + // Scale 0 (pure integer): 42 + assert_preimage_range( + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(43), 9, 0), + ); + + // ===== Decimal64 ===== + assert_preimage_range( + ScalarValue::Decimal64(Some(10000), 18, 2), + ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00 + ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal64(Some(-500), 18, 2), + ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00 + ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(100), 18, 2), + ); + + // ===== Decimal128 ===== + assert_preimage_range( + ScalarValue::Decimal128(Some(10000), 38, 2), + ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00 + ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00 + ); + + // Negative + 
assert_preimage_range( + ScalarValue::Decimal128(Some(-500), 38, 2), + ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00 + ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(100), 38, 2), + ); + + // ===== Decimal256 ===== + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00 + ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00 + ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::from(100)), 76, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_non_integer() { + // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer decimals + + // Decimal32 + assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50 + assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70 + assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01 + + // Decimal64 + assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50 + + // Decimal128 + assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50 + + // Decimal256 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 
76, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50 + + // Decimal32: i32::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2)); + + // Decimal64: i64::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2)); + } + + #[test] + fn test_floor_preimage_decimal_overflow() { + // Test near MAX where adding scale_factor would overflow + + // Decimal32: i32::MAX + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 10, 0)); + + // Decimal64: i64::MAX + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 19, 0)); + } + + #[test] + fn test_floor_preimage_decimal_edge_cases() { + // ===== Decimal32 ===== + // Large value that doesn't overflow + // Decimal(9,2) max value is 9,999,999.99 (stored as 999,999,999) + // Use a large value that fits Decimal(9,2) and is divisible by 100 + let safe_max_aligned_32 = 999_999_900; // 9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), + ); + + // Negative edge: use a large negative value that fits Decimal(9,2) + // Decimal(9,2) min value is -9,999,999.99 (stored as -999,999,999) + let min_aligned_32 = -999_999_900; // -9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_null() { + assert_preimage_none(ScalarValue::Decimal32(None, 9, 2)); + assert_preimage_none(ScalarValue::Decimal64(None, 18, 2)); + assert_preimage_none(ScalarValue::Decimal128(None, 38, 2)); + 
assert_preimage_none(ScalarValue::Decimal256(None, 76, 2)); + } +} diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index ba4afc5622eb3..aa93d797eb7b3 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -18,12 +18,19 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, ArrowNativeTypeOp, AsArray, BooleanArray}; -use arrow::datatypes::DataType::{Boolean, Float16, Float32, Float64}; -use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; +use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray}; +use arrow::datatypes::DataType::{ + Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, + Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; -use datafusion_common::types::NativeType; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{Coercion, TypeSignatureClass}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -31,8 +38,6 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", @@ -60,14 +65,10 @@ impl Default for IsZeroFunc { impl IsZeroFunc { pub fn new() -> Self { - // Accept any numeric type and coerce to float - let float = Coercion::new_implicit( - TypeSignatureClass::Float, - vec![TypeSignatureClass::Numeric], - NativeType::Float64, - ); + // Accept any numeric type (ints, uints, 
floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::coercible(vec![float], Volatility::Immutable), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -90,79 +91,155 @@ impl ScalarUDFImpl for IsZeroFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // Handle NULL input - if args.args[0].data_type().is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + ScalarValue::Int8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + + ScalarValue::Decimal32(Some(v), ..) 
=> { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal64(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal128(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + _ => { + internal_err!( + "Unexpected scalar type for iszero: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null( + array.len(), + )))), + + Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))), + + Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + + 
Decimal32(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal64(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal128(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal256(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))) + } + + other => { + internal_err!("Unexpected data type {other:?} for function iszero") + } + }, } - make_scalar_function(iszero, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { self.doc() } } - -/// Iszero SQL function -fn iszero(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - Float32 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - Float16 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x.is_zero(), - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function iszero"), - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - - use datafusion_common::cast::as_boolean_array; - - use crate::math::iszero::iszero; - - #[test] - fn test_iszero_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_iszero_f32() { - let args: 
Vec = - vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 03f246c28be19..632eafe1e009a 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,13 +17,21 @@ //! Math function: `isnan()`. -use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; -use datafusion_common::types::NativeType; -use datafusion_common::{Result, ScalarValue, exec_err}; -use datafusion_expr::{Coercion, ColumnarValue, ScalarFunctionArgs, TypeSignatureClass}; - use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::DataType::{ + Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, Int8, Int16, + Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -55,14 +63,10 @@ impl Default for IsNanFunc { impl IsNanFunc { pub fn new() -> Self { - // Accept any numeric type and coerce to float - let float = Coercion::new_implicit( - TypeSignatureClass::Float, 
- vec![TypeSignatureClass::Numeric], - NativeType::Float64, - ); + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::coercible(vec![float], Volatility::Immutable), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -84,36 +88,123 @@ impl ScalarUDFImpl for IsNanFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // Handle NULL input - if args.args[0].data_type().is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); - } + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let result = match scalar { + ScalarValue::Float64(Some(v)) => Some(v.is_nan()), + ScalarValue::Float32(Some(v)) => Some(v.is_nan()), + ScalarValue::Float16(Some(v)) => Some(v.is_nan()), - let args = ColumnarValue::values_to_arrays(&args.args)?; - - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f64::is_nan, - )) as ArrayRef, - - DataType::Float32 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f32::is_nan, - )) as ArrayRef, - - DataType::Float16 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x.is_nan(), - )) as ArrayRef, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + // Non-float numeric inputs are never NaN + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, 
_) => Some(false), + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result))) } - }; - Ok(ColumnarValue::Array(arr)) + ColumnarValue::Array(array) => { + // NOTE: BooleanArray::from_unary preserves nulls. + let arr: ArrayRef = match array.data_type() { + Null => Arc::new(BooleanArray::new_null(array.len())) as ArrayRef, + + Float64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f64::is_nan, + )) as ArrayRef, + Float32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f32::is_nan, + )) as ArrayRef, + Float16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_nan(), + )) as ArrayRef, + + // Non-float numeric arrays are never NaN + Decimal32(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal64(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal128(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal256(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + Int8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt64 
=> Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 6daf476e250d3..2bdc3fbbc64ac 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,12 +18,10 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::make_scalar_function; - use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array}; use arrow::datatypes::DataType::{Float16, Float32, Float64}; use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -101,7 +99,24 @@ impl ScalarUDFImpl for NanvlFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(nanvl, vec![])(&args.args) + let [x, y] = take_function_args(self.name(), args.args)?; + + match (x, y) { + (ColumnarValue::Scalar(ScalarValue::Float16(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (x @ ColumnarValue::Scalar(_), _) => Ok(x), + (x, y) => { + let args = ColumnarValue::values_to_arrays(&[x, y])?; + Ok(ColumnarValue::Array(nanvl(&args)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -110,42 +125,49 @@ impl ScalarUDFImpl for NanvlFunc { } /// Nanvl SQL function +/// +/// - x is NaN -> output is y 
(which may itself be NULL) +/// - otherwise -> output is x (which may itself be NULL) fn nanvl(args: &[ArrayRef]) -> Result { match args[0].data_type() { Float64 => { - let compute_nanvl = |x: f64, y: f64| { - if x.is_nan() { y } else { x } - }; - - let x = args[0].as_primitive() as &Float64Array; - let y = args[1].as_primitive() as &Float64Array; - arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float64Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } Float32 => { - let compute_nanvl = |x: f32, y: f32| { - if x.is_nan() { y } else { x } - }; - - let x = args[0].as_primitive() as &Float32Array; - let y = args[1].as_primitive() as &Float32Array; - arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float32Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } Float16 => { - let compute_nanvl = - |x: ::Native, - y: ::Native| { - if x.is_nan() { y } else { x } - }; - - let x = args[0].as_primitive() as &Float16Array; - let y = args[1].as_primitive() as &Float16Array; - arrow::compute::binary::<_, _, _, Float16Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float16Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + 
.collect(); + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -163,8 +185,8 @@ mod test { #[test] fn test_nanvl_f64() { let args: Vec = vec![ - Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y - Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // x + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); @@ -181,8 +203,8 @@ mod test { #[test] fn test_nanvl_f32() { let args: Vec = vec![ - Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y - Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // x + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index de70788128b88..8c25c57740d5f 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -31,7 +31,7 @@ use arrow::error::ArrowError; use datafusion_common::types::{ NativeType, logical_float32, logical_float64, logical_int32, }; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -141,7 +141,67 @@ impl ScalarUDFImpl for RoundFunc { &default_decimal_places }; - round_columnar(&args.args[0], decimal_places, args.number_rows) + // Scalar fast path for float and decimal types - avoid array conversion overhead + if let (ColumnarValue::Scalar(value_scalar), 
ColumnarValue::Scalar(dp_scalar)) = + (&args.args[0], decimal_places) + { + if value_scalar.is_null() || dp_scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None); + } + + let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar { + *dp + } else { + return internal_err!( + "Unexpected datatype for decimal_places: {}", + dp_scalar.data_type() + ); + }; + + match value_scalar { + ScalarValue::Float32(Some(v)) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + ScalarValue::Float64(Some(v)) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + ScalarValue::Decimal128(Some(v), precision, scale) => { + let rounded = round_decimal(*v, *scale, dp)?; + let scalar = + ScalarValue::Decimal128(Some(rounded), *precision, *scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + let rounded = round_decimal(*v, *scale, dp)?; + let scalar = + ScalarValue::Decimal256(Some(rounded), *precision, *scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + let rounded = round_decimal(*v, *scale, dp)?; + let scalar = + ScalarValue::Decimal64(Some(rounded), *precision, *scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ScalarValue::Decimal32(Some(v), precision, scale) => { + let rounded = round_decimal(*v, *scale, dp)?; + let scalar = + ScalarValue::Decimal32(Some(rounded), *precision, *scale); + Ok(ColumnarValue::Scalar(scalar)) + } + _ => { + internal_err!( + "Unexpected datatype for value: {}", + value_scalar.data_type() + ) + } + } + } else { + round_columnar(&args.args[0], decimal_places, args.number_rows) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index e217088c64c2e..8a3769a12f294 100644 --- 
a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -18,11 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -30,8 +31,6 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = r#"Returns the sign of a number. @@ -98,7 +97,53 @@ impl ScalarUDFImpl for SignumFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(signum, vec![])(&args.args) + let return_type = args.return_type().clone(); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(&return_type, None); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected scalar type for signum: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float64Type>( + |x: f64| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + 
Float32 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float32Type>( + |x: f32| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + other => { + internal_err!("Unsupported data type {other:?} for function signum") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,33 +151,6 @@ impl ScalarUDFImpl for SignumFunc { } } -/// signum SQL function -fn signum(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>( - |x: f64| { - if x == 0_f64 { 0_f64 } else { x.signum() } - }, - ), - ) as ArrayRef), - - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>( - |x: f32| { - if x == 0_f32 { 0_f32 } else { x.signum() } - }, - ), - ) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function signum"), - } -} - #[cfg(test)] mod test { use std::sync::Arc; diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 6727ba8fbdf08..ecdad22e8af11 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -24,7 +24,7 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -110,7 +110,50 @@ impl ScalarUDFImpl for TruncFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(trunc, vec![])(&args.args) + // Extract precision from second argument (default 0) + let precision = match args.args.get(1) { + Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p), + 
Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision + Some(ColumnarValue::Array(_)) => { + // Precision is an array - use array path + return make_scalar_function(trunc, vec![])(&args.args); + } + None => Some(0), // default precision + Some(cv) => { + return exec_err!( + "trunc function requires precision to be Int64, got {:?}", + cv.data_type() + ); + } + }; + + // Scalar fast path using tuple matching for (value, precision) + match (&args.args[0], precision) { + // Null cases + (ColumnarValue::Scalar(sv), _) if sv.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + (_, None) => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + // Scalar cases + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 { + v.trunc() + } else { + compute_truncate64(*v, p) + }))), + ), + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 { + v.trunc() + } else { + compute_truncate32(*v, p) + }))), + ), + // Array path for everything else + _ => make_scalar_function(trunc, vec![])(&args.args), + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -202,12 +245,12 @@ fn trunc(args: &[ArrayRef]) -> Result { fn compute_truncate32(x: f32, y: i64) -> f32 { let factor = 10.0_f32.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } fn compute_truncate64(x: f64, y: i64) -> f64 { let factor = 10.0_f64.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } #[cfg(test)] @@ -238,9 +281,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 15.0); - assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(1), 1_234.267); assert_eq!(floats.value(2), 1_233.12); - assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(3), 3.312_97); 
assert_eq!(floats.value(4), -21.123_4); } @@ -263,9 +306,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(1), 234.267); assert_eq!(floats.value(2), 123.12); - assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(3), 123.312_97); assert_eq!(floats.value(4), -321.123_1); } diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index bc06d462c04bd..439a2dba06954 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -356,7 +356,7 @@ fn handle_regexp_like( .map_err(|e| arrow_datafusion_err!(e))? } (Utf8, LargeUtf8) => { - let value = values.as_string_view(); + let value = values.as_string::(); let pattern = patterns.as_string::(); regexp::regexp_is_match(value, pattern, flags) diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index fe3c508edea07..bfd035ed3c0db 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; -use datafusion_common::{Result, internal_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_expr_common::signature::Coercion; @@ -91,7 +91,31 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(ascii, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); + } + + match scalar { + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => { + let result = s.chars().next().map_or(0, |c| c as i32); + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function ascii", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(ascii(&[array])?)), + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index ba011b94367e3..2f432c838e010 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -18,24 +18,21 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::array::GenericStringBuilder; +use arrow::array::{ArrayRef, GenericStringBuilder, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; -use 
crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; -use datafusion_common::{Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the character with the given code. /// chr(65) = 'A' -fn chr(args: &[ArrayRef]) -> Result { - let integer_array = as_int64_array(&args[0])?; - +fn chr_array(integer_array: &Int64Array) -> Result { let mut builder = GenericStringBuilder::::with_capacity( integer_array.len(), // 1 byte per character, assuming that is the common case @@ -56,15 +53,11 @@ fn chr(args: &[ArrayRef]) -> Result { return exec_err!("invalid Unicode scalar value: {integer}"); } - None => { - builder.append_null(); - } + None => builder.append_null(), } } - let result = builder.finish(); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[user_doc( @@ -119,7 +112,32 @@ impl ScalarUDFImpl for ChrFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(chr, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(code_point))) => { + if let Ok(u) = u32::try_from(code_point) + && let Some(c) = core::char::from_u32(u) + { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + c.to_string(), + )))) + } else { + exec_err!("invalid Unicode scalar value: {code_point}") + } + } + ColumnarValue::Scalar(ScalarValue::Int64(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Array(array) => { + let integer_array = as_int64_array(&array)?; + Ok(ColumnarValue::Array(chr_array(integer_array)?)) + } + other => internal_err!( + "Unexpected data type {:?} for function chr", + other.data_type() + ), + } } fn 
documentation(&self) -> Option<&Documentation> { @@ -130,13 +148,27 @@ impl ScalarUDFImpl for ChrFunc { #[cfg(test)] mod tests { use super::*; + use arrow::array::{Array, Int64Array, StringArray}; + use arrow::datatypes::Field; use datafusion_common::assert_contains; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + fn invoke_chr(arg: ColumnarValue, number_rows: usize) -> Result { + ChrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![arg], + arg_fields: vec![Field::new("a", Int64, true).into()], + number_rows, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + } #[test] fn test_chr_normal() { let input = Arc::new(Int64Array::from(vec![ - Some(0), // null + Some(0), // \u{0000} Some(65), // A Some(66), // B Some(67), // C @@ -149,8 +181,13 @@ mod tests { Some(9), // tab Some(0x10FFFF), // 0x10FFFF, the largest Unicode code point ])); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + + let result = invoke_chr(ColumnarValue::Array(input), 12).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); + let expected = [ "\u{0000}", "A", @@ -174,55 +211,48 @@ mod tests { #[test] fn test_chr_error() { - // invalid Unicode code points (too large) let input = Arc::new(Int64Array::from(vec![i64::MAX])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 9223372036854775807" ); - // invalid Unicode code points (too large) case 2 let input = Arc::new(Int64Array::from(vec![0x10FFFF + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( 
result.err().unwrap().to_string(), "invalid Unicode scalar value: 1114112" ); - // invalid Unicode code points (surrogate code point) - // link: let input = Arc::new(Int64Array::from(vec![0xD800 + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 55297" ); - // negative input - let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); // will be 2 if cast to u32 - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -9223372036854775806" ); - // negative input case 2 let input = Arc::new(Int64Array::from(vec![-1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -1" ); - // one error with valid values after - let input = Arc::new(Int64Array::from(vec![65, -1, 66])); // A, -1, B - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![65, -1, 66])); + let result = invoke_chr(ColumnarValue::Array(input), 3); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), @@ -232,10 +262,36 @@ mod tests { #[test] fn test_chr_empty() { - // empty input array let input = Arc::new(Int64Array::from(Vec::::new())); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + let result = invoke_chr(ColumnarValue::Array(input), 0).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); assert_eq!(string_array.len(), 0); } + + #[test] + fn test_chr_scalar() { + let result = + 
invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(Some(65))), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "A"); + } + other => panic!("Unexpected result: {other:?}"), + } + } + + #[test] + fn test_chr_scalar_null() { + let result = + invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(None)), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + other => panic!("Unexpected result: {other:?}"), + } + } } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 9e565342bafbc..e674541253288 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -120,13 +120,10 @@ impl ScalarUDFImpl for ConcatFunc { } }); - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { @@ -207,7 +204,9 @@ impl ScalarUDFImpl for ConcatFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array.len(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index b08799f434aa6..9d3b32eedf8fd 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{Array, StringArray, as_largestring_array}; +use arrow::array::Array; use std::any::Any; use std::sync::Arc; @@ -25,7 +25,9 @@ use crate::string::concat; use crate::string::concat::simplify_concat; use crate::string::concat_ws; use crate::strings::{ColumnarValueRef, StringArrayBuilder}; -use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; @@ -105,7 +107,6 @@ impl ScalarUDFImpl for ConcatWsFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; - // do not accept 0 arguments. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments. It requires at least 2.", @@ -113,18 +114,14 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { let ColumnarValue::Scalar(scalar) = &args[0] else { - // loop above checks for all args being scalar unreachable!() }; let sep = match scalar.try_as_str() { @@ -139,7 +136,6 @@ impl ScalarUDFImpl for ConcatWsFunc { let mut values = Vec::with_capacity(args.len() - 1); for arg in &args[1..] 
{ let ColumnarValue::Scalar(scalar) = arg else { - // loop above checks for all args being scalar unreachable!() }; @@ -162,23 +158,53 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); // estimate - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) } - } - _ => unreachable!("concat ws"), + Some(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + None => { + return internal_err!("Expected string separator, got {scalar:?}"); + } + }, + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + DataType::LargeUtf8 => { + let string_array = as_large_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + } + } + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + data_size += + string_array.total_buffer_bytes_used() * (args.len() - 2); 
+ if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + } + } + other => { + return plan_err!( + "Input was {other} which is not a supported datatype for concat_ws separator" + ); + } + }, }; let mut columns = Vec::with_capacity(args.len() - 1); @@ -206,7 +232,7 @@ impl ScalarUDFImpl for ConcatWsFunc { columns.push(column); } DataType::LargeUtf8 => { - let string_array = as_largestring_array(array); + let string_array = as_large_string_array(array)?; data_size += string_array.values().len(); let column = if array.is_nullable() { @@ -221,11 +247,9 @@ impl ScalarUDFImpl for ConcatWsFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array - .data_buffers() - .iter() - .map(|buf| buf.len()) - .sum::(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { @@ -251,18 +275,14 @@ impl ScalarUDFImpl for ConcatWsFunc { continue; } - let mut iter = columns.iter(); - for column in iter.by_ref() { + let mut first = true; + for column in &columns { if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } builder.write::(column, i); - break; - } - } - - for column in iter { - if column.is_valid(i) { - builder.write::(&sep, i); - builder.write::(column, i); + first = false; } } @@ -546,4 +566,78 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_utf8view_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, 
true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 2ca5e190c6e02..65f320c4f9f13 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,16 +18,17 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::utf8_to_str_type; use 
arrow::array::{ - ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -99,7 +100,63 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(repeat, vec![])(&args.args) + let return_type = args.return_field.data_type().clone(); + let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + + // Early return if either argument is a scalar null + if let ColumnarValue::Scalar(s) = &string_arg + && s.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + if let ColumnarValue::Scalar(c) = &count_arg + && c.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + + match (&string_arg, &count_arg) { + ( + ColumnarValue::Scalar(string_scalar), + ColumnarValue::Scalar(count_scalar), + ) => { + let count = match count_scalar { + ScalarValue::Int64(Some(n)) => *n, + _ => { + return internal_err!( + "Unexpected data type {:?} for repeat count", + count_scalar.data_type() + ); + } + }; + + let result = match string_scalar { + ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { + ScalarValue::Utf8(Some(compute_repeat( + s, + count, + i32::MAX 
as usize, + )?)) + } + ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( + compute_repeat(s, count, i64::MAX as usize)?, + )), + _ => { + return internal_err!( + "Unexpected data type {:?} for function repeat", + string_scalar.data_type() + ); + } + }; + + Ok(ColumnarValue::Scalar(result)) + } + _ => { + let string_array = string_arg.to_array(args.number_rows)?; + let count_array = count_arg.to_array(args.number_rows)?; + Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -107,13 +164,30 @@ impl ScalarUDFImpl for RepeatFunc { } } +/// Computes repeat for a single string value with max size check +#[inline] +fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { + if count <= 0 { + return Ok(String::new()); + } + let result_len = s.len().saturating_mul(count as usize); + if result_len > max_size { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + result_len + ); + } + Ok(s.repeat(count as usize)) +} + /// Repeats string the specified number of times. 
/// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let number_array = as_int64_array(&args[1])?; - match args[0].data_type() { +fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { + let number_array = as_int64_array(count_array)?; + match string_array.data_type() { Utf8View => { - let string_view_array = args[0].as_string_view(); + let string_view_array = string_array.as_string_view(); repeat_impl::( &string_view_array, number_array, @@ -121,17 +195,17 @@ fn repeat(args: &[ArrayRef]) -> Result { ) } Utf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i32::MAX as usize, ) } LargeUtf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i64::MAX as usize, ) @@ -150,7 +224,7 @@ fn repeat_impl<'a, T, S>( ) -> Result where T: OffsetSizeTrait, - S: StringArrayType<'a>, + S: StringArrayType<'a> + 'a, { let mut total_capacity = 0; let mut max_item_capacity = 0; @@ -181,37 +255,55 @@ where // Reusable buffer to avoid allocations in string.repeat() let mut buffer = Vec::::with_capacity(max_item_capacity); - string_array - .iter() - .zip(number_array.iter()) - .for_each(|(string, number)| { + // Helper function to repeat a string into a buffer using doubling strategy + // count must be > 0 + #[inline] + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: usize) { + buffer.clear(); + if !string.is_empty() { + let src = string.as_bytes(); + // Initial copy + buffer.extend_from_slice(src); + // Doubling strategy: copy what we have so far until we reach the target + while buffer.len() < src.len() * count { + let copy_len = buffer.len().min(src.len() * count - buffer.len()); + // SAFETY: we're copying valid UTF-8 bytes that we already verified + buffer.extend_from_within(..copy_len); + } + } + } + + // Fast path: no 
nulls in either array + if string_array.null_count() == 0 && number_array.null_count() == 0 { + for i in 0..string_array.len() { + // SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value + let string = unsafe { string_array.value_unchecked(i) }; + let count = number_array.value(i); + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } else { + // Slow path: handle nulls + for (string, number) in string_array.iter().zip(number_array.iter()) { match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - buffer.clear(); - let count = number as usize; - if count > 0 && !string.is_empty() { - let src = string.as_bytes(); - // Initial copy - buffer.extend_from_slice(src); - // Doubling strategy: copy what we have so far until we reach the target - while buffer.len() < src.len() * count { - let copy_len = - buffer.len().min(src.len() * count - buffer.len()); - // SAFETY: we're copying valid UTF-8 bytes that we already verified - buffer.extend_from_within(..copy_len); - } - } - // SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str + (Some(string), Some(count)) if count > 0 => { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); } (Some(_), Some(_)) => builder.append_value(""), _ => builder.append_null(), } - }); - let array = builder.finish(); + } + } - Ok(Arc::new(array) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 165e0634a6b80..458b86d0c6fb0 100644 --- 
a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -228,19 +228,21 @@ fn replace_into_string(buffer: &mut String, string: &str, from: &str, to: &str) return; } - // Fast path for replacing a single ASCII character with another single ASCII character - // This matches Rust's str::replace() optimization and enables vectorization + // Fast path for replacing a single ASCII character with another single ASCII character. + // Extends the buffer's underlying Vec directly, for performance. if let ([from_byte], [to_byte]) = (from.as_bytes(), to.as_bytes()) && from_byte.is_ascii() && to_byte.is_ascii() { - // SAFETY: We're replacing ASCII with ASCII, which preserves UTF-8 validity - let replaced: Vec = string - .as_bytes() - .iter() - .map(|b| if *b == *from_byte { *to_byte } else { *b }) - .collect(); - buffer.push_str(unsafe { std::str::from_utf8_unchecked(&replaced) }); + // SAFETY: Replacing an ASCII byte with another ASCII byte preserves UTF-8 validity. + unsafe { + buffer.as_mut_vec().extend( + string + .as_bytes() + .iter() + .map(|&b| if b == *from_byte { *to_byte } else { b }), + ); + } return; } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 74bf7c16c43a1..e24dbd63d147d 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -48,7 +48,10 @@ use std::sync::Arc; ```"#, standard_argument(name = "str", prefix = "String"), argument(name = "delimiter", description = "String or character to split on."), - argument(name = "pos", description = "Position of the part to return.") + argument( + name = "pos", + description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string." 
+ ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct SplitPartFunc { @@ -233,7 +236,7 @@ where std::cmp::Ordering::Less => { // Negative index: use rsplit().nth() to efficiently get from the end // rsplit iterates in reverse, so -1 means first from rsplit (index 0) - let idx: usize = (-n - 1).try_into().map_err(|_| { + let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| { exec_datafusion_err!( "split_part index {n} exceeds minimum supported value" ) @@ -324,6 +327,20 @@ mod tests { Utf8, StringArray ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 891cbe2549579..ed8ce07b876d5 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -18,7 +18,6 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::datatypes::{ @@ -26,7 +25,7 @@ use arrow::datatypes::{ Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -38,11 +37,11 @@ const HEX_CHARS: &[u8; 16] = b"0123456789abcdef"; /// Converts the number to its equivalent hexadecimal representation. 
/// to_hex(2147483647) = '7fffffff' -fn to_hex(args: &[ArrayRef]) -> Result +fn to_hex_array(array: &ArrayRef) -> Result where T::Native: ToHex, { - let integer_array = as_primitive_array::(&args[0])?; + let integer_array = as_primitive_array::(array)?; let len = integer_array.len(); // Max hex string length: 16 chars for u64/i64 @@ -78,6 +77,14 @@ where Ok(Arc::new(result) as ArrayRef) } +#[inline] +fn to_hex_scalar(value: T) -> String { + let mut hex_buffer = [0u8; 16]; + let hex_len = value.write_hex_to_buffer(&mut hex_buffer); + // SAFETY: hex_buffer is ASCII hex digits + unsafe { std::str::from_utf8_unchecked(&hex_buffer[16 - hex_len..]).to_string() } +} + /// Trait for converting integer types to hexadecimal in a buffer trait ToHex: ArrowNativeType { /// Write hex representation to buffer and return the number of hex digits written. @@ -223,33 +230,71 @@ impl ScalarUDFImpl for ToHexFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Null => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - DataType::Int64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int16 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt16 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int8 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::UInt8 => { - make_scalar_function(to_hex::, vec![])(&args.args) + let arg = &args.args[0]; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => Ok( + 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + + // NULL scalars + ColumnarValue::Scalar(s) if s.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - other => exec_err!("Unsupported data type {other:?} for function to_hex"), + + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + }, + + other => internal_err!( + "Unexpected argument type {:?} for function to_hex", + other.data_type() + ), } } @@ -288,8 +333,8 @@ mod tests { let 
expected = $expected; let array = <$array_type>::from(input); - let array_ref = Arc::new(array); - let hex_result = to_hex::<$arrow_type>(&[array_ref])?; + let array_ref: ArrayRef = Arc::new(array); + let hex_result = to_hex_array::<$arrow_type>(&array_ref)?; let hex_array = as_string_array(&hex_result)?; let expected_array = StringArray::from(expected); diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index a7be3ef792994..cfddf57b094b5 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -152,43 +152,34 @@ impl StringViewArrayBuilder { } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NullableLargeStringArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NullableStringViewArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.push_str(array.value(i)); } } ColumnarValueRef::NonNullableArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } ColumnarValueRef::NonNullableStringViewArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.push_str(array.value(i)); } } } pub fn append_offset(&mut self) { self.builder.append_value(&self.block); - self.block = String::new(); + self.block.clear(); } pub fn finish(mut self) -> StringViewArray { diff 
--git a/datafusion/functions/src/unicode/common.rs b/datafusion/functions/src/unicode/common.rs new file mode 100644 index 0000000000000..93f0c7900961e --- /dev/null +++ b/datafusion/functions/src/unicode/common.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Common utilities for implementing unicode functions + +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, ByteView, GenericStringArray, Int64Array, + OffsetSizeTrait, StringViewArray, make_view, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{NullBuffer, ScalarBuffer}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; +use datafusion_common::exec_err; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; + +/// A trait for `left` and `right` byte slicing operations +pub(crate) trait LeftRightSlicer { + fn slice(string: &str, n: i64) -> Range; +} + +pub(crate) struct LeftSlicer {} + +impl LeftRightSlicer for LeftSlicer { + fn slice(string: &str, n: i64) -> Range { + 0..left_right_byte_length(string, n) + } +} + +pub(crate) struct RightSlicer {} + +impl LeftRightSlicer for RightSlicer { + fn slice(string: &str, n: i64) -> Range { + if n == 0 { + // Return nothing for `n=0` + 0..0 + } else if n == i64::MIN { + // Special case for i64::MIN overflow + 0..0 + } else { + left_right_byte_length(string, -n)..string.len() + } + } +} + +/// Calculate the byte length of the substring of `n` chars from string `string` +#[inline] +fn left_right_byte_length(string: &str, n: i64) -> usize { + match n.cmp(&0) { + Ordering::Less => string + .char_indices() + .nth_back((n.unsigned_abs().min(usize::MAX as u64) - 1) as usize) + .map(|(index, _)| index) + .unwrap_or(0), + Ordering::Equal => 0, + Ordering::Greater => string + .char_indices() + .nth(n.unsigned_abs().min(usize::MAX as u64) as usize) + .map(|(index, _)| index) + .unwrap_or(string.len()), + } +} + +/// General implementation for `left` and `right` functions +pub(crate) fn general_left_right( + args: &[ArrayRef], +) -> datafusion_common::Result { + let n_array = as_int64_array(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let string_array = as_generic_string_array::(&args[0])?; + 
general_left_right_array::(string_array, n_array) + } + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&args[0])?; + general_left_right_array::(string_array, n_array) + } + DataType::Utf8View => { + let string_view_array = as_string_view_array(&args[0])?; + general_left_right_view::(string_view_array, n_array) + } + _ => exec_err!("Not supported"), + } +} + +/// `general_left_right` implementation for strings +fn general_left_right_array< + 'a, + T: OffsetSizeTrait, + V: ArrayAccessor, + F: LeftRightSlicer, +>( + string_array: V, + n_array: &Int64Array, +) -> datafusion_common::Result { + let iter = ArrayIter::new(string_array); + let result = iter + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => { + let range = F::slice(string, n); + // Extract a given range from a byte-indexed slice + Some(&string[range]) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// `general_left_right` implementation for StringViewArray +fn general_left_right_view( + string_view_array: &StringViewArray, + n_array: &Int64Array, +) -> datafusion_common::Result { + let len = n_array.len(); + + let views = string_view_array.views(); + // Every string in StringViewArray has one corresponding view in `views` + debug_assert!(views.len() == string_view_array.len()); + + // Compose null buffer at once + let string_nulls = string_view_array.nulls(); + let n_nulls = n_array.nulls(); + let new_nulls = NullBuffer::union(string_nulls, n_nulls); + + let new_views = (0..len) + .map(|idx| { + let view = views[idx]; + + let is_valid = match &new_nulls { + Some(nulls_buf) => nulls_buf.is_valid(idx), + None => true, + }; + + if is_valid { + let string: &str = string_view_array.value(idx); + let n = n_array.value(idx); + + // Input string comes from StringViewArray, so it should fit in 32-bit length + let range = F::slice(string, n); + let result_bytes = &string.as_bytes()[range.clone()]; + + let byte_view = 
ByteView::from(view); + // New offset starts at 0 for left, and at `range.start` for right, + // which is encoded in the given range + let new_offset = byte_view.offset + (range.start as u32); + // Reuse buffer + make_view(result_bytes, byte_view.buffer_index, new_offset) + } else { + // For nulls, keep the original view + view + } + }) + .collect::>(); + + // Buffers are unchanged + let result = StringViewArray::try_new( + ScalarBuffer::from(new_views), + Vec::from(string_view_array.data_buffers()), + new_nulls, + )?; + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index a25c37266c2ca..0feb637924264 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -98,9 +98,8 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let ScalarFunctionArgs { args, .. } = args; - - let [string, str_list] = take_function_args(self.name(), args)?; + let return_field = args.return_field; + let [string, str_list] = take_function_args(self.name(), args.args)?; match (string, str_list) { // both inputs are scalars @@ -141,7 +140,7 @@ impl ScalarUDFImpl for FindInSetFunc { ) => { let result_array = match str_list_literal { // find_in_set(column_a, null) = null - None => new_null_array(str_array.data_type(), str_array.len()), + None => new_null_array(return_field.data_type(), str_array.len()), Some(str_list_literal) => { let str_list = str_list_literal.split(',').collect::>(); let result = match str_array.data_type() { @@ -190,7 +189,7 @@ impl ScalarUDFImpl for FindInSetFunc { let res = match string_literal { // find_in_set(null, column_b) = null None => { - new_null_array(str_list_array.data_type(), str_list_array.len()) + new_null_array(return_field.data_type(), str_list_array.len()) } Some(string) => { let result = match str_list_array.data_type() { diff --git 
a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index 929b0c316951b..e2fc9130992db 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -99,6 +99,39 @@ impl ScalarUDFImpl for InitcapFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + let arg = &args.args[0]; + + // Scalar fast path - handle directly without array conversion + if let ColumnarValue::Scalar(scalar) = arg { + return match scalar { + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Utf8View(None) => Ok(arg.clone()), + ScalarValue::Utf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + ScalarValue::LargeUtf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + ScalarValue::Utf8View(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + other => { + exec_err!( + "Unsupported data type {:?} for function `initcap`", + other.data_type() + ) + } + }; + } + + // Array path let args = &args.args; match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index db27d900b6828..76873e7f5d3e1 100644 --- 
a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -16,20 +16,11 @@ // under the License. use std::any::Any; -use std::cmp::Ordering; -use std::sync::Arc; -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{LeftSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::Result; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -94,22 +85,26 @@ impl ScalarUDFImpl for LeftFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "left") + Ok(arg_types[0].clone()) } + /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. + /// left('abcde', 2) = 'ab' + /// left('abcde', -2) = 'abc' + /// The implementation uses UTF-8 code points as characters fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(left::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function left,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,63 +114,10 @@ impl ScalarUDFImpl for LeftFunc { } } -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
-/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -fn left(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - - if args[0].data_type() == &DataType::Utf8View { - let string_array = as_string_view_array(&args[0])?; - left_impl::(string_array, n_array) - } else { - let string_array = as_generic_string_array::(&args[0])?; - left_impl::(string_array, n_array) - } -} - -fn left_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array: V, - n_array: &Int64Array, -) -> Result { - let iter = ArrayIter::new(string_array); - let mut chars_buf = Vec::new(); - let result = iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - // Collect chars once and reuse for both count and take - chars_buf.clear(); - chars_buf.extend(string.chars()); - let len = chars_buf.len() as i64; - - // For negative n, take (len + n) chars if n > -len (avoiding abs() which panics on i64::MIN) - Some(if n > -len { - chars_buf - .iter() - .take((len + n) as usize) - .collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -218,6 +160,17 @@ mod tests { Utf8, StringArray ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); test_function!( LeftFunc::new(), vec![ @@ -299,6 +252,74 @@ mod 
tests { StringArray ); + // StringView cases + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ab")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséésoj".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("joséé")), + &str, + Utf8View, + StringViewArray + ); + + // Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input + .chars() + .take(input.chars().count() - n) + .collect::(); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + Ok(Some(expected.as_str())), + &str, + Utf8, + StringArray + ); + } + Ok(()) } } diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 4a0dd21d749af..7250b3915fb5c 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; pub mod character_length; +pub mod common; pub mod find_in_set; pub mod initcap; pub mod left; diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index ac98a3f202a5b..a97e242b73f9e 100644 --- 
a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -16,20 +16,11 @@ // under the License. use std::any::Any; -use std::cmp::{Ordering, max}; -use std::sync::Arc; -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{RightSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::Result; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -94,22 +85,26 @@ impl ScalarUDFImpl for RightFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "right") + Ok(arg_types[0].clone()) } + /// Returns right n characters in the string, or when n is negative, returns all but first |n| characters. + /// right('abcde', 2) = 'de' + /// right('abcde', -2) = 'cde' + /// The implementation uses UTF-8 code points as characters fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(right::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function right,\ - expected Utf8View, Utf8 or LargeUtf8." 
+ "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,58 +114,10 @@ impl ScalarUDFImpl for RightFunc { } } -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. -/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -fn right(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - if args[0].data_type() == &DataType::Utf8View { - // string_view_right(args) - let string_array = as_string_view_array(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } else { - // string_right::(args) - let string_array = &as_generic_string_array::(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } -} - -// Currently the return type can only be Utf8 or LargeUtf8, to reach fully support, we need -// to edit the `get_optimal_return_type` in utils.rs to make the udfs be able to return Utf8View -// See https://github.com/apache/datafusion/issues/11790#issuecomment-2283777166 -fn right_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array_iter: &mut ArrayIter, - n_array: &Int64Array, -) -> Result { - let result = string_array_iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use 
datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -213,6 +160,17 @@ mod tests { Utf8, StringArray ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); test_function!( RightFunc::new(), vec![ @@ -260,10 +218,10 @@ mod tests { test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], - Ok(Some("éésoj")), + Ok(Some("érend")), &str, Utf8, StringArray @@ -271,10 +229,10 @@ mod tests { test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], - Ok(Some("éésoj")), + Ok(Some("éérend")), &str, Utf8, StringArray @@ -294,6 +252,71 @@ mod tests { StringArray ); + // StringView cases + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("de")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséérend".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("éérend")), + &str, + Utf8View, + StringViewArray + ); + + 
// Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input.chars().skip(n).collect::(); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + Ok(Some(expected.as_str())), + &str, + Utf8, + StringArray + ); + } + Ok(()) } } diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 9be086c4cf5fc..c1d6ecffe5510 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -32,6 +32,7 @@ use datafusion_expr::{ Volatility, }; use datafusion_macros::user_doc; +use memchr::memchr; #[user_doc( doc_section(label = "String Functions"), @@ -179,6 +180,31 @@ fn strpos(args: &[ArrayRef]) -> Result { } } +/// Find `needle` in `haystack` using `memchr` to quickly skip to positions +/// where the first byte matches, then verify the remaining bytes. Using +/// string::find is slower because it has significant per-call overhead that +/// `memchr` does not, and strpos is often invoked many times on short inputs. +/// Returns a 1-based position, or 0 if not found. +/// Both inputs must be ASCII-only. +fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize { + let needle_len = needle.len(); + let first_byte = needle[0]; + let mut offset = 0; + + while let Some(pos) = memchr(first_byte, &haystack[offset..]) { + let start = offset + pos; + if start + needle_len > haystack.len() { + return 0; + } + if haystack[start..start + needle_len] == *needle { + return start + 1; + } + offset = start + 1; + } + + 0 +} + /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) 
/// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters @@ -198,37 +224,25 @@ where .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // If only ASCII characters are present, we can use the slide window method to find - // the sub vector in the main vector. This is faster than string.find() method. + if substring.is_empty() { + return T::Native::from_usize(1); + } + + let substring_bytes = substring.as_bytes(); + let string_bytes = string.as_bytes(); + + if substring_bytes.len() > string_bytes.len() { + return T::Native::from_usize(0); + } + if ascii_only { - // If the substring is empty, the result is 1. - if substring.is_empty() { - T::Native::from_usize(1) - } else { - T::Native::from_usize( - string - .as_bytes() - .windows(substring.len()) - .position(|w| w == substring.as_bytes()) - .map(|x| x + 1) - .unwrap_or(0), - ) - } + T::Native::from_usize(find_ascii_substring( + string_bytes, + substring_bytes, + )) } else { // For non-ASCII, use a single-pass search that tracks both // byte position and character position simultaneously - if substring.is_empty() { - return T::Native::from_usize(1); - } - - let substring_bytes = substring.as_bytes(); - let string_bytes = string.as_bytes(); - - if substring_bytes.len() > string_bytes.len() { - return T::Native::from_usize(0); - } - - // Single pass: find substring while counting characters let mut char_pos = 0; for (byte_idx, _) in string.char_indices() { char_pos += 1; diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index cc1d53b3aad67..505388089f198 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -176,7 +176,7 @@ fn substr(args: &[ArrayRef]) -> Result { // `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` // `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` // `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` -fn 
get_true_start_end( +pub fn get_true_start_end( input: &str, start: i64, count: Option, @@ -185,7 +185,10 @@ fn get_true_start_end( let start = start.checked_sub(1).unwrap_or(start); let end = match count { - Some(count) => start + count as i64, + Some(count) => { + let count_i64 = i64::try_from(count).unwrap_or(i64::MAX); + start.saturating_add(count_i64) + } None => input.len() as i64, }; let count_to_end = count.is_some(); @@ -235,7 +238,7 @@ fn get_true_start_end( // string, such as `substr(long_str_with_1k_chars, 1, 32)`. // In such case the overhead of ASCII-validation may not be worth it, so // skip the validation for short prefix for now. -fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( +pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( string_array: &V, start: &Int64Array, count: Option<&Int64Array>, @@ -247,7 +250,7 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( // HACK: can be simplified if function has specialized // implementation for `ScalarValue` (implement without `make_scalar_function()`) - let avg_prefix_len = start + let total_prefix_len = start .iter() .zip(count.iter()) .take(n_sample) @@ -255,11 +258,11 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( let start = start.unwrap_or(0); let count = count.unwrap_or(0); // To get substring, need to decode from 0 to start+count instead of start to start+count - start + count + start.saturating_add(count) }) - .sum::(); + .fold(0i64, |acc, val| acc.saturating_add(val)); - avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + (total_prefix_len as f64 / n_sample as f64) <= short_prefix_threshold } None => false, }; @@ -810,7 +813,7 @@ mod tests { SubstrFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ], Ok(Some("abc")), &str, @@ -821,7 +824,7 @@ mod tests { SubstrFunc::new(), vec![ 
ColumnarValue::Scalar(ScalarValue::from("overflow")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], exec_err!("negative overflow when calculating skip value"), @@ -829,6 +832,18 @@ mod tests { Utf8View, StringViewArray ); + test_function!( + SubstrFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("large count")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ], + Ok(Some("arge count")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index f97c0ed5c2990..e86eaf8111b1c 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -35,8 +35,8 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), - description = "Translates characters in a string to specified translation characters.", - syntax_example = "translate(str, chars, translation)", + description = "Performs character-wise substitution based on a mapping.", + syntax_example = "translate(str, from, to)", sql_example = r#"```sql > select translate('twice', 'wic', 'her'); +--------------------------------------------------+ @@ -46,10 +46,10 @@ use datafusion_macros::user_doc; +--------------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "chars", description = "Characters to translate."), + argument(name = "from", description = "The characters to be replaced."), argument( - name = "translation", - description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." + name = "to", + description = "The characters to replace them with. 
Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -71,6 +71,7 @@ impl TranslateFunc { vec![ Exact(vec![Utf8View, Utf8, Utf8]), Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), ], Volatility::Immutable, ), @@ -99,6 +100,61 @@ impl ScalarUDFImpl for TranslateFunc { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { + // When from and to are scalars, pre-build the translation map once + if let (Some(from_str), Some(to_str)) = ( + try_as_scalar_str(&args.args[1]), + try_as_scalar_str(&args.args[2]), + ) { + let to_graphemes: Vec<&str> = to_str.graphemes(true).collect(); + + let mut from_map: HashMap<&str, usize> = HashMap::new(); + for (index, c) in from_str.graphemes(true).enumerate() { + // Ignore characters that already exist in from_map + from_map.entry(c).or_insert(index); + } + + let ascii_table = build_ascii_translate_table(from_str, to_str); + + let string_array = args.args[0].to_array_of_size(args.number_rows)?; + + let result = match string_array.data_type() { + DataType::Utf8View => { + let arr = string_array.as_string_view(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + translate_with_map::( + arr, + &from_map, + &to_graphemes, + ascii_table.as_ref(), + ) + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function translate" + ); + } + }?; + + return Ok(ColumnarValue::Array(result)); + } + make_scalar_function(invoke_translate, 
vec![])(&args.args) } @@ -107,6 +163,14 @@ impl ScalarUDFImpl for TranslateFunc { } } +/// If `cv` is a non-null scalar string, return its value. +fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { + match cv { + ColumnarValue::Scalar(s) => s.try_as_str().flatten(), + _ => None, + } +} + fn invoke_translate(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8View => { @@ -123,8 +187,8 @@ fn invoke_translate(args: &[ArrayRef]) -> Result { } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); - let from_array = args[1].as_string::(); - let to_array = args[2].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); translate::(string_array, from_array, to_array) } other => { @@ -170,7 +234,7 @@ where // Build from_map using reusable buffer from_graphemes.extend(from.graphemes(true)); for (index, c) in from_graphemes.iter().enumerate() { - // Ignore characters that already exist in from_map, else insert + // Ignore characters that already exist in from_map from_map.entry(*c).or_insert(index); } @@ -199,6 +263,97 @@ where Ok(Arc::new(result) as ArrayRef) } +/// Sentinel value in the ASCII translate table indicating the character should +/// be deleted (the `from` character has no corresponding `to` character). Any +/// value > 127 works since valid ASCII is 0–127. +const ASCII_DELETE: u8 = 0xFF; + +/// If `from` and `to` are both ASCII, build a fixed-size lookup table for +/// translation. Each entry maps an input byte to its replacement byte, or to +/// [`ASCII_DELETE`] if the character should be removed. Returns `None` if +/// either string contains non-ASCII characters. 
+fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { + if !from.is_ascii() || !to.is_ascii() { + return None; + } + let mut table = [0u8; 128]; + for i in 0..128u8 { + table[i as usize] = i; + } + let to_bytes = to.as_bytes(); + let mut seen = [false; 128]; + for (i, from_byte) in from.bytes().enumerate() { + let idx = from_byte as usize; + if !seen[idx] { + seen[idx] = true; + if i < to_bytes.len() { + table[idx] = to_bytes[i]; + } else { + table[idx] = ASCII_DELETE; + } + } + } + Some(table) +} + +/// Optimized translate for constant `from` and `to` arguments: uses a pre-built +/// translation map instead of rebuilding it for every row. When an ASCII byte +/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII +/// inputs fallback to using the map. +fn translate_with_map<'a, T: OffsetSizeTrait, V>( + string_array: V, + from_map: &HashMap<&str, usize>, + to_graphemes: &[&str], + ascii_table: Option<&[u8; 128]>, +) -> Result +where + V: ArrayAccessor, +{ + let mut result_graphemes: Vec<&str> = Vec::new(); + let mut ascii_buf: Vec = Vec::new(); + + let result = ArrayIter::new(string_array) + .map(|string| { + string.map(|s| { + // Fast path: byte-level table lookup for ASCII strings + if let Some(table) = ascii_table + && s.is_ascii() + { + ascii_buf.clear(); + for &b in s.as_bytes() { + let mapped = table[b as usize]; + if mapped != ASCII_DELETE { + ascii_buf.push(mapped); + } + } + // SAFETY: all bytes are ASCII, hence valid UTF-8. 
+ return unsafe { + std::str::from_utf8_unchecked(&ascii_buf).to_owned() + }; + } + + // Slow path: grapheme-based translation + result_graphemes.clear(); + + for c in s.graphemes(true) { + match from_map.get(c) { + Some(n) => { + if let Some(replacement) = to_graphemes.get(*n) { + result_graphemes.push(*replacement); + } + } + None => result_graphemes.push(c), + } + } + + result_graphemes.concat() + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + #[cfg(test)] mod tests { use arrow::array::{Array, StringArray}; @@ -284,6 +439,21 @@ mod tests { Utf8, StringArray ); + // Non-ASCII input with ASCII scalar from/to: exercises the + // grapheme fallback within translate_with_map. + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("café")), + ColumnarValue::Scalar(ScalarValue::from("ae")), + ColumnarValue::Scalar(ScalarValue::from("AE")) + ], + Ok(Some("cAfé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e4980728b18a0..b9bde1454994c 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -147,7 +147,7 @@ where if scalar.is_null() { // Null scalar is castable to any numeric, creating a non-null expression. 
// Provide null array explicitly to make result null - PrimitiveArray::::new_null(1) + PrimitiveArray::::new_null(left.len()) } else { let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( @@ -363,12 +363,30 @@ pub mod test { }; } - use arrow::datatypes::DataType; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Int32Type}, + }; use itertools::Either; pub(crate) use test_function; use super::*; + #[test] + fn test_calculate_binary_math_scalar_null() { + let left = Int32Array::from(vec![1, 2]); + let right = ColumnarValue::Scalar(ScalarValue::Int32(None)); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + } + #[test] fn string_to_int_type() { let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index 85833bf11649d..da26de7fe2174 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -45,5 +45,5 @@ proc-macro = true [dependencies] datafusion-doc = { workspace = true } -quote = "1.0.41" -syn = { version = "2.0.113", features = ["full"] } +quote = "1.0.44" +syn = { version = "2.0.116", features = ["full"] } diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index 27f73fd955380..ce9e7d55ef103 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -20,7 +20,6 @@ html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] #![cfg_attr(docsrs, feature(doc_cfg))] -#![deny(clippy::allow_attributes)] extern crate proc_macro; use datafusion_doc::scalar_doc_sections::doc_sections_const; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 15d3261ca5132..76d3f73f68767 100644 --- 
a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -55,7 +55,7 @@ itertools = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } -regex-syntax = "0.8.6" +regex-syntax = "0.8.9" [dev-dependencies] async-trait = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 02395c76bdd92..a98678f7cf9c4 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -36,22 +36,22 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, Sort, WindowFunction, + InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; -use datafusion_expr::type_coercion::functions::fields_with_udf; +use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; +use datafusion_expr::type_coercion::is_datetime; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; -use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - AggregateUDF, Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not, + Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union, + WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, 
is_not_false, is_not_true, + is_not_unknown, is_true, is_unknown, lit, not, }; /// Performs type coercion by determining the schema @@ -500,6 +500,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, )))) } + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => { + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + if (expr_type.is_numeric() && subquery_type.is_string()) + || (subquery_type.is_numeric() && expr_type.is_string()) + { + return plan_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ); + } + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ), + )?; + let new_subquery = Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }; + Ok(Transformed::yes(Expr::SetComparison(SetComparison::new( + Box::new(expr.cast_to(&common_type, self.schema)?), + cast_subquery(new_subquery, &common_type)?, + op, + quantifier, + )))) + } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, @@ -637,11 +674,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let new_expr = coerce_arguments_for_signature_with_scalar_udf( - args, - self.schema, - &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) @@ -657,11 +691,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { null_treatment, }, }) => { - let new_expr = coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, 
- &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( func, @@ -692,13 +723,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { - coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, - udf, - )? + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? + } + expr::WindowFunctionDefinition::WindowUDF(udf) => { + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? } - _ => args, }; let new_expr = Expr::from(WindowFunction { @@ -859,12 +888,15 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() - || is_utf8_or_utf8view_or_large_utf8(col_type) - || matches!(col_type, DataType::List(_)) - || matches!(col_type, DataType::LargeList(_)) - || matches!(col_type, DataType::FixedSizeList(_, _)) - || matches!(col_type, DataType::Null) - || matches!(col_type, DataType::Boolean) + || col_type.is_string() + || col_type.is_null() + || matches!( + col_type, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Boolean + ) { Ok(col_type.clone()) } else if is_datetime(col_type) { @@ -917,40 +949,10 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_scalar_udf( - expressions: Vec, - schema: &DFSchema, - func: &ScalarUDF, -) -> Result> { - if expressions.is_empty() { - return Ok(expressions); - } - - let current_fields = expressions - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - - let coerced_types = fields_with_udf(¤t_fields, func)? 
- .into_iter() - .map(|f| f.data_type().clone()) - .collect::>(); - - expressions - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) - .collect() -} - -/// Returns `expressions` coerced to types compatible with -/// `signature`, if possible. -/// -/// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_aggregate_udf( +fn coerce_arguments_for_signature( expressions: Vec, schema: &DFSchema, - func: &AggregateUDF, + func: &F, ) -> Result> { if expressions.is_empty() { return Ok(expressions); @@ -1890,7 +1892,7 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); + assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed")); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d9273a8f60fb2..2096c42770315 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -34,7 +34,9 @@ use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{BinaryExpr, Case, Expr, Operator, SortExpr, col}; +use datafusion_expr::{ + BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col, +}; const CSE_PREFIX: &str = "__common_expr"; @@ -698,10 +700,27 @@ impl CSEController for ExprCSEController<'_> { } fn is_ignored(&self, node: &Expr) -> bool { + // MoveTowardsLeafNodes expressions (e.g. 
get_field) are cheap struct + // field accesses that the ExtractLeafExpressions / PushDownLeafProjections + // rules deliberately duplicate when needed (one copy for a filter + // predicate, another for an output column). CSE deduplicating them + // creates intermediate projections that fight with those rules, + // causing optimizer instability — ExtractLeafExpressions will undo + // the dedup, creating an infinite loop that runs until the iteration + // limit is hit. Skip them. + if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes { + return true; + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( node, + // TODO: there's an argument for removing `Literal` from here, + // maybe using `Expr::placemement().should_push_to_leaves()` instead + // so that we extract common literals and don't broadcast them to num_batch_rows multiple times. + // However that currently breaks things like `percentile_cont()` which expect literal arguments + // (and would instead be getting `col(__common_expr_n)`). Expr::Literal(..) | Expr::Column(..) | Expr::ScalarVariable(..) @@ -825,6 +844,7 @@ mod test { use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; use datafusion_expr::test::function_stub::{avg, sum}; @@ -1826,4 +1846,56 @@ mod test { panic!("dummy - not implemented") } } + + /// Identical MoveTowardsLeafNodes expressions should NOT be deduplicated + /// by CSE — they are cheap (e.g. struct field access) and the extraction + /// rules deliberately duplicate them. Deduplicating causes optimizer + /// instability where one optimizer rule will undo the work of another, + /// resulting in an infinite optimization loop until the + /// we hit the max iteration limit and then give up. 
+ #[test] + fn test_leaf_expression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + let leaf = leaf_udf_expr(col("a")); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])? + .build()?; + + // Plan should be unchanged — no __common_expr introduced + assert_optimized_plan_equal!( + plan, + @r" + Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2 + TableScan: test + " + ) + } + + /// When a MoveTowardsLeafNodes expression appears as a sub-expression of + /// a larger expression that IS duplicated, only the outer expression gets + /// deduplicated; the leaf sub-expression stays inline. + #[test] + fn test_leaf_subexpression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + // leaf_udf(a) + b appears twice — the outer `+` is a common + // sub-expression, but leaf_udf(a) by itself is MoveTowardsLeafNodes + // and should not be extracted separately. + let common = leaf_udf_expr(col("a")) + col("b"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![common.clone().alias("c1"), common.alias("c2")])? + .build()?; + + // The whole `leaf_udf(a) + b` gets deduplicated as __common_expr_1, + // but leaf_udf(a) alone is NOT pulled out. 
+ assert_optimized_plan_equal!( + plan, + @r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) + } } diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c8acb044876c4..b9d160d55589f 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -27,7 +27,10 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{Column, Result, assert_or_internal_err, plan_err}; +use datafusion_common::{ + Column, DFSchemaRef, ExprSchema, NullEquality, Result, assert_or_internal_err, + plan_err, +}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; @@ -310,6 +313,39 @@ fn mark_join( ) } +/// Check if join keys in the join filter may contain NULL values +/// +/// Returns true if any join key column is nullable on either side. +/// This is used to optimize null-aware anti joins: if all join keys are non-nullable, +/// we can use a regular anti join instead of the more expensive null-aware variant. 
+fn join_keys_may_be_null( + join_filter: &Expr, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, +) -> Result { + // Extract columns from the join filter + let mut columns = std::collections::HashSet::new(); + expr_to_columns(join_filter, &mut columns)?; + + // Check if any column is nullable + for col in columns { + // Check in left schema + if let Ok(field) = left_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + // Check in right schema + if let Ok(field) = right_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + } + + Ok(false) +} + fn build_join( left: &LogicalPlan, subquery: &LogicalPlan, @@ -403,6 +439,8 @@ fn build_join( // Degenerate case: no right columns referenced by the predicate(s) sub_query_alias.clone() }; + + // Mark joins don't use null-aware semantics (they use three-valued logic with mark column) let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(right_projected, join_type, Some(join_filter))? .build()?; @@ -415,10 +453,36 @@ fn build_join( return Ok(Some(new_plan)); } + // Determine if this should be a null-aware anti join + // Null-aware semantics are only needed for NOT IN subqueries, not NOT EXISTS: + // - NOT IN: Uses three-valued logic, requires null-aware handling + // - NOT EXISTS: Uses two-valued logic, regular anti join is correct + // We can distinguish them: NOT IN has in_predicate_opt, NOT EXISTS does not + // + // Additionally, if the join keys are non-nullable on both sides, we don't need + // null-aware semantics because NULLs cannot exist in the data. + let null_aware = matches!(join_type, JoinType::LeftAnti) + && in_predicate_opt.is_some() + && join_keys_may_be_null(&join_filter, left.schema(), sub_query_alias.schema())?; + // join our sub query into the main plan - let new_plan = LogicalPlanBuilder::from(left.clone()) - .join_on(sub_query_alias, join_type, Some(join_filter))? 
- .build()?; + let new_plan = if null_aware { + // Use join_detailed_with_options to set null_aware flag + LogicalPlanBuilder::from(left.clone()) + .join_detailed_with_options( + sub_query_alias, + join_type, + (Vec::::new(), Vec::::new()), // No equijoin keys, filter-based join + Some(join_filter), + NullEquality::NullEqualsNothing, + true, // null_aware + )? + .build()? + } else { + LogicalPlanBuilder::from(left.clone()) + .join_on(sub_query_alias, join_type, Some(join_filter))? + .build()? + }; debug!( "predicate subquery optimized:\n{}", new_plan.display_indent() @@ -1977,7 +2041,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [arr:Int32;N] Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] - TableScan: sq [arr:List(Field { data_type: Int32, nullable: true });N] + TableScan: sq [arr:List(Int32);N] " ) } @@ -2012,7 +2076,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [a:UInt32;N] Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] - TableScan: sq [a:List(Field { data_type: UInt32, nullable: true });N] + TableScan: sq [a:List(UInt32);N] " ) } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 770291566346c..3cb0516a6d296 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -341,6 +341,7 @@ fn find_inner_join( filter: None, schema: join_schema, null_equality, + null_aware: false, })); } } @@ -363,6 +364,7 @@ fn find_inner_join( join_type: JoinType::Inner, join_constraint: JoinConstraint::On, null_equality, + null_aware: false, })) } @@ -522,7 +524,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 
[a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -608,7 +610,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -634,7 +636,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -856,7 +858,7 @@ mod tests { plan, @ r" Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] @@ -936,7 +938,7 @@ mod tests { TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -1010,7 +1012,7 @@ mod tests { Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] @@ -1246,7 +1248,7 @@ mod tests { plan, @ r" Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1367,6 +1369,7 @@ mod tests { filter: None, schema: join_schema, null_equality: NullEquality::NullEqualsNull, // Test preservation + null_aware: false, }); // Apply filter that can create join conditions diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 2c78051c14134..58abe38d04bc7 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -119,6 +119,7 @@ impl OptimizerRule for 
EliminateOuterJoin { filter: join.filter.clone(), schema: Arc::clone(&join.schema), null_equality: join.null_equality, + null_aware: join.null_aware, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index a623faf8a2ff0..0a50761e8a9f7 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -76,6 +76,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -117,6 +118,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { // According to `is not distinct from`'s semantics, it's // safe to override it null_equality: NullEquality::NullEqualsNull, + null_aware, }))); } } @@ -132,6 +134,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -143,6 +146,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } } diff --git a/datafusion/optimizer/src/extract_leaf_expressions.rs b/datafusion/optimizer/src/extract_leaf_expressions.rs new file mode 100644 index 0000000000000..f5f4982e38c65 --- /dev/null +++ b/datafusion/optimizer/src/extract_leaf_expressions.rs @@ -0,0 +1,3022 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Two-pass optimizer pipeline that pushes cheap expressions (like struct field +//! access `user['status']`) closer to data sources, enabling early data reduction +//! and source-level optimizations (e.g., Parquet column pruning). See +//! [`ExtractLeafExpressions`] (pass 1) and [`PushDownLeafProjections`] (pass 2). + +use indexmap::{IndexMap, IndexSet}; +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, Result, qualified_name}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{Expr, ExpressionPlacement, Projection}; + +use crate::optimizer::ApplyOrder; +use crate::push_down_filter::replace_cols_by_name; +use crate::utils::has_all_column_refs; +use crate::{OptimizerConfig, OptimizerRule}; + +/// Prefix for aliases generated by the extraction optimizer passes. +/// +/// This prefix is **reserved for internal optimizer use**. User-defined aliases +/// starting with this prefix may be misidentified as optimizer-generated +/// extraction aliases, leading to unexpected behavior. Do not use this prefix +/// in user queries. +const EXTRACTED_EXPR_PREFIX: &str = "__datafusion_extracted"; + +/// Returns `true` if any sub-expression in `exprs` has +/// [`ExpressionPlacement::MoveTowardsLeafNodes`] placement. 
+/// +/// This is a lightweight pre-check that short-circuits as soon as one +/// extractable expression is found, avoiding the expensive allocations +/// (column HashSets, extractors, expression rewrites) that the full +/// extraction pipeline requires. +fn has_extractable_expr(exprs: &[Expr]) -> bool { + exprs.iter().any(|expr| { + expr.exists(|e| Ok(e.placement() == ExpressionPlacement::MoveTowardsLeafNodes)) + .unwrap_or(false) + }) +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from non-projection nodes +/// into **extraction projections** (pass 1 of 2). +/// +/// This handles Filter, Sort, Limit, Aggregate, and Join nodes. For Projection +/// nodes, extraction and pushdown are handled by [`PushDownLeafProjections`]. +/// +/// # Key Concepts +/// +/// **Extraction projection**: a projection inserted *below* a node that +/// pre-computes a cheap expression and exposes it under an alias +/// (`__datafusion_extracted_N`). The parent node then references the alias +/// instead of the original expression. +/// +/// **Recovery projection**: a projection inserted *above* a node to restore +/// the original output schema when extraction changes it. +/// Schema-preserving nodes (Filter, Sort, Limit) gain extra columns from +/// the extraction projection that bubble up; the recovery projection selects +/// only the original columns to hide the extras. +/// +/// # Example +/// +/// Given a filter with a struct field access: +/// +/// ```text +/// Filter: user['status'] = 'active' +/// TableScan: t [id, user] +/// ``` +/// +/// This rule: +/// 1. Inserts an **extraction projection** below the filter: +/// 2. 
Adds a **recovery projection** above to hide the extra column: +/// +/// ```text +/// Projection: id, user <-- recovery projection +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction projection +/// TableScan: t [id, user] +/// ``` +/// +/// **Important:** The `PushDownFilter` rule is aware of projections created by this rule +/// and will not push filters through them. It uses `ExpressionPlacement` to detect +/// `MoveTowardsLeafNodes` expressions and skip filter pushdown past them. +#[derive(Default, Debug)] +pub struct ExtractLeafExpressions {} + +impl ExtractLeafExpressions { + /// Create a new [`ExtractLeafExpressions`] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for ExtractLeafExpressions { + fn name(&self) -> &str { + "extract_leaf_expressions" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + extract_from_plan(plan, alias_generator) + } +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from a plan node. +/// +/// Works for any number of inputs (0, 1, 2, …N). For multi-input nodes +/// like Join, each extracted sub-expression is routed to the correct input +/// by checking which input's schema contains all of the expression's column +/// references. +fn extract_from_plan( + plan: LogicalPlan, + alias_generator: &Arc, +) -> Result> { + // Only extract from plan types whose output schema is predictable after + // expression rewriting. Nodes like Window derive column names from + // their expressions, so rewriting `get_field` inside a window function + // changes the output schema and breaks the recovery projection. 
+ if !matches!( + &plan, + LogicalPlan::Aggregate(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Join(_) + ) { + return Ok(Transformed::no(plan)); + } + + let inputs = plan.inputs(); + if inputs.is_empty() { + return Ok(Transformed::no(plan)); + } + + // Fast pre-check: skip all allocations if no extractable expressions exist + if !has_extractable_expr(&plan.expressions()) { + return Ok(Transformed::no(plan)); + } + + // Save original output schema before any transformation + let original_schema = Arc::clone(plan.schema()); + + // Build per-input schemas from borrowed inputs (before plan is consumed + // by map_expressions). We only need schemas and column sets for routing; + // the actual inputs are cloned later only if extraction succeeds. + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + + // Build per-input extractors + let mut extractors: Vec = input_schemas + .iter() + .map(|schema| LeafExpressionExtractor::new(schema.as_ref(), alias_generator)) + .collect(); + + // Build per-input column sets for routing expressions to the correct input + let input_column_sets: Vec> = input_schemas + .iter() + .map(|schema| schema_columns(schema.as_ref())) + .collect(); + + // Transform expressions via map_expressions with routing + let transformed = plan.map_expressions(|expr| { + routing_extract(expr, &mut extractors, &input_column_sets) + })?; + + // If no expressions were rewritten, nothing was extracted + if !transformed.transformed { + return Ok(transformed); + } + + // Clone inputs now that we know extraction succeeded. Wrap in Arc + // upfront since build_extraction_projection expects &Arc. 
+ let owned_inputs: Vec> = transformed + .data + .inputs() + .into_iter() + .map(|i| Arc::new(i.clone())) + .collect(); + + // Build per-input extraction projections (None means no extractions for that input) + let new_inputs: Vec = owned_inputs + .into_iter() + .zip(extractors.iter()) + .map(|(input_arc, extractor)| { + match extractor.build_extraction_projection(&input_arc)? { + Some(plan) => Ok(plan), + // No extractions for this input — recover the LogicalPlan + // without cloning (refcount is 1 since build returned None). + None => { + Ok(Arc::try_unwrap(input_arc).unwrap_or_else(|arc| (*arc).clone())) + } + } + }) + .collect::>>()?; + + // Rebuild the plan keeping its rewritten expressions but replacing + // inputs with the new extraction projections. + let new_plan = transformed + .data + .with_new_exprs(transformed.data.expressions(), new_inputs)?; + + // Add recovery projection if the output schema changed + let recovered = build_recovery_projection(original_schema.as_ref(), new_plan)?; + + Ok(Transformed::yes(recovered)) +} + +/// Given an expression, returns the index of the input whose columns fully +/// cover the expression's column references. +/// Returns `None` if the expression references columns from multiple inputs +/// or if multiple inputs match (ambiguous, e.g. unqualified columns present +/// in both sides of a join). +fn find_owning_input( + expr: &Expr, + input_column_sets: &[std::collections::HashSet], +) -> Option { + let mut found = None; + for (idx, cols) in input_column_sets.iter().enumerate() { + if has_all_column_refs(expr, cols) { + if found.is_some() { + // Ambiguous — multiple inputs match + return None; + } + found = Some(idx); + } + } + found +} + +/// Walks an expression tree top-down, extracting `MoveTowardsLeafNodes` +/// sub-expressions and routing each to the correct per-input extractor. 
+fn routing_extract( + expr: Expr, + extractors: &mut [LeafExpressionExtractor], + input_column_sets: &[std::collections::HashSet], +) -> Result> { + expr.transform_down(|e| { + // Skip expressions already aliased with extracted expression pattern + if let Expr::Alias(alias) = &e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Ok(Transformed { + data: e, + transformed: false, + tnr: TreeNodeRecursion::Jump, + }); + } + + // Don't extract Alias nodes directly — preserve the alias and let + // transform_down recurse into the inner expression + if matches!(&e, Expr::Alias(_)) { + return Ok(Transformed::no(e)); + } + + match e.placement() { + ExpressionPlacement::MoveTowardsLeafNodes => { + if let Some(idx) = find_owning_input(&e, input_column_sets) { + let col_ref = extractors[idx].add_extracted(e)?; + Ok(Transformed::yes(col_ref)) + } else { + // References columns from multiple inputs — cannot extract + Ok(Transformed::no(e)) + } + } + ExpressionPlacement::Column => { + // Track columns that the parent node references so the + // extraction projection includes them as pass-through. + // Without this, the extraction projection would only + // contain __datafusion_extracted_N aliases, and the parent couldn't + // resolve its other column references. + if let Expr::Column(col) = &e + && let Some(idx) = find_owning_input(&e, input_column_sets) + { + extractors[idx].columns_needed.insert(col.clone()); + } + Ok(Transformed::no(e)) + } + _ => Ok(Transformed::no(e)), + } + }) +} + +/// Returns all columns in the schema (both qualified and unqualified forms) +fn schema_columns(schema: &DFSchema) -> std::collections::HashSet { + schema + .iter() + .flat_map(|(qualifier, field)| { + [ + Column::new(qualifier.cloned(), field.name()), + Column::new_unqualified(field.name()), + ] + }) + .collect() +} + +/// Rewrites extraction pairs and column references from one qualifier +/// space to another. 
+/// +/// Builds a replacement map by zipping `from_schema` (whose qualifiers +/// currently appear in `pairs` / `columns`) with `to_schema` (the +/// qualifiers we want), then applies `replace_cols_by_name`. +/// +/// Used for SubqueryAlias (alias-space -> input-space) and Union +/// (union output-space -> per-branch input-space). +fn remap_pairs_and_columns( + pairs: &[(Expr, String)], + columns: &IndexSet, + from_schema: &DFSchema, + to_schema: &DFSchema, +) -> Result { + let mut replace_map = HashMap::new(); + for ((from_q, from_f), (to_q, to_f)) in from_schema.iter().zip(to_schema.iter()) { + replace_map.insert( + qualified_name(from_q, from_f.name()), + Expr::Column(Column::new(to_q.cloned(), to_f.name())), + ); + } + let remapped_pairs: Vec<(Expr, String)> = pairs + .iter() + .map(|(expr, alias)| { + Ok(( + replace_cols_by_name(expr.clone(), &replace_map)?, + alias.clone(), + )) + }) + .collect::>()?; + let remapped_columns: IndexSet = columns + .iter() + .filter_map(|col| { + let rewritten = + replace_cols_by_name(Expr::Column(col.clone()), &replace_map).ok()?; + if let Expr::Column(c) = rewritten { + Some(c) + } else { + Some(col.clone()) + } + }) + .collect(); + Ok(ExtractionTarget { + pairs: remapped_pairs, + columns: remapped_columns, + }) +} + +// ============================================================================= +// Helper Types & Functions for Extraction Targeting +// ============================================================================= + +/// A bundle of extraction pairs (expression + alias) and standalone columns +/// that need to be pushed through a plan node. +struct ExtractionTarget { + /// Extracted expressions paired with their generated aliases. + pairs: Vec<(Expr, String)>, + /// Standalone column references needed by the parent node. + columns: IndexSet, +} + +/// Build a replacement map from a projection: output_column_name -> underlying_expr. 
+/// +/// This is used to resolve column references through a renaming projection. +/// For example, if a projection has `user AS x`, this maps `x` -> `col("user")`. +fn build_projection_replace_map(projection: &Projection) -> HashMap { + projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect() +} + +/// Build a recovery projection to restore the original output schema. +/// +/// After extraction, a node's output schema may differ from the original: +/// +/// - **Schema-preserving nodes** (Filter/Sort/Limit): the extraction projection +/// below adds extra `__datafusion_extracted_N` columns that bubble up through +/// the node. Recovery selects only the original columns to hide the extras. +/// ```text +/// Original schema: [id, user] +/// After extraction: [__datafusion_extracted_1, id, user] ← extra column leaked through +/// Recovery: SELECT id, user FROM ... ← hides __datafusion_extracted_1 +/// ``` +/// +/// - **Schema-defining nodes** (Aggregate): same number of columns but names +/// may differ because extracted aliases replaced the original expressions. +/// Recovery maps positionally, aliasing where names changed. +/// ```text +/// Original: [SUM(user['balance'])] +/// After: [SUM(__datafusion_extracted_1)] ← name changed +/// Recovery: SUM(__datafusion_extracted_1) AS "SUM(user['balance'])" +/// ``` +/// +/// - **Schemas identical** → no recovery projection needed. 
+fn build_recovery_projection( + original_schema: &DFSchema, + input: LogicalPlan, +) -> Result { + let new_schema = input.schema(); + let orig_len = original_schema.fields().len(); + let new_len = new_schema.fields().len(); + + if orig_len == new_len { + // Same number of fields — check if schemas are identical + let schemas_match = original_schema.iter().zip(new_schema.iter()).all( + |((orig_q, orig_f), (new_q, new_f))| { + orig_f.name() == new_f.name() && orig_q == new_q + }, + ); + if schemas_match { + return Ok(input); + } + + // Schema-defining nodes (Aggregate, Join): names may differ at some + // positions because extracted aliases replaced the original expressions. + // Map positionally, aliasing where the name changed. + // + // Invariant: `with_new_exprs` on all supported node types (Aggregate, + // Filter, Sort, Limit, Join) preserves column order, so positional + // mapping is safe here. + debug_assert!( + orig_len == new_len, + "build_recovery_projection: positional mapping requires same field count, \ + got original={orig_len} vs new={new_len}" + ); + let mut proj_exprs = Vec::with_capacity(orig_len); + for (i, (orig_qualifier, orig_field)) in original_schema.iter().enumerate() { + let (new_qualifier, new_field) = new_schema.qualified_field(i); + if orig_field.name() == new_field.name() && orig_qualifier == new_qualifier { + proj_exprs.push(Expr::from((orig_qualifier, orig_field))); + } else { + let new_col = Expr::Column(Column::from((new_qualifier, new_field))); + proj_exprs.push( + new_col.alias_qualified(orig_qualifier.cloned(), orig_field.name()), + ); + } + } + let projection = Projection::try_new(proj_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } else { + // Schema-preserving nodes: new schema has extra extraction columns. + // Original columns still exist by name; select them to hide extras. 
+ let col_exprs: Vec = original_schema.iter().map(Expr::from).collect(); + let projection = Projection::try_new(col_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } +} + +/// Collects `MoveTowardsLeafNodes` sub-expressions found during expression +/// tree traversal and can build an extraction projection from them. +/// +/// # Example +/// +/// Given `Filter: user['status'] = 'active' AND user['name'] IS NOT NULL`: +/// - `add_extracted(user['status'])` → stores it, returns `col("__datafusion_extracted_1")` +/// - `add_extracted(user['name'])` → stores it, returns `col("__datafusion_extracted_2")` +/// - `build_extraction_projection()` produces: +/// `Projection: user['status'] AS __datafusion_extracted_1, user['name'] AS __datafusion_extracted_2, ` +struct LeafExpressionExtractor<'a> { + /// Extracted expressions: maps expression -> alias + extracted: IndexMap, + /// Columns referenced by extracted expressions or the parent node, + /// included as pass-through in the extraction projection. + columns_needed: IndexSet, + /// Input schema + input_schema: &'a DFSchema, + /// Alias generator + alias_generator: &'a Arc, +} + +impl<'a> LeafExpressionExtractor<'a> { + fn new(input_schema: &'a DFSchema, alias_generator: &'a Arc) -> Self { + Self { + extracted: IndexMap::new(), + columns_needed: IndexSet::new(), + input_schema, + alias_generator, + } + } + + /// Adds an expression to extracted set, returns column reference. 
+ fn add_extracted(&mut self, expr: Expr) -> Result { + // Deduplication: reuse existing alias if same expression + if let Some(alias) = self.extracted.get(&expr) { + return Ok(Expr::Column(Column::new_unqualified(alias))); + } + + // Track columns referenced by this expression + for col in expr.column_refs() { + self.columns_needed.insert(col.clone()); + } + + // Generate unique alias + let alias = self.alias_generator.next(EXTRACTED_EXPR_PREFIX); + self.extracted.insert(expr, alias.clone()); + + Ok(Expr::Column(Column::new_unqualified(&alias))) + } + + /// Builds an extraction projection above the given input, or merges into + /// it if the input is already a projection. Delegates to + /// [`build_extraction_projection_impl`]. + /// + /// Returns `None` if there are no extractions. + fn build_extraction_projection( + &self, + input: &Arc, + ) -> Result> { + if self.extracted.is_empty() { + return Ok(None); + } + let pairs: Vec<(Expr, String)> = self + .extracted + .iter() + .map(|(e, a)| (e.clone(), a.clone())) + .collect(); + let proj = build_extraction_projection_impl( + &pairs, + &self.columns_needed, + input, + self.input_schema, + )?; + Ok(Some(LogicalPlan::Projection(proj))) + } +} + +/// Build an extraction projection above the target node (shared by both passes). +/// +/// If the target is an existing projection, merges into it. This requires +/// resolving column references through the projection's rename mapping: +/// if the projection has `user AS u`, and an extracted expression references +/// `u['name']`, we must rewrite it to `user['name']` since the merged +/// projection reads from the same input as the original. +/// +/// Deduplicates by resolved expression equality and adds pass-through +/// columns as needed. Otherwise builds a fresh projection with extracted +/// expressions + ALL input schema columns. 
+fn build_extraction_projection_impl( + extracted_exprs: &[(Expr, String)], + columns_needed: &IndexSet, + target: &Arc, + target_schema: &DFSchema, +) -> Result { + if let LogicalPlan::Projection(existing) = target.as_ref() { + // Merge into existing projection + let mut proj_exprs = existing.expr.clone(); + + // Build a map of existing expressions (by Expr equality) to their aliases + let existing_extractions: IndexMap = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Alias(alias) = e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Some((*alias.expr.clone(), alias.name.clone())); + } + None + }) + .collect(); + + // Resolve column references through the projection's rename mapping + let replace_map = build_projection_replace_map(existing); + + // Add new extracted expressions, resolving column refs through the projection + for (expr, alias) in extracted_exprs { + let resolved = replace_cols_by_name(expr.clone().alias(alias), &replace_map)?; + let resolved_inner = if let Expr::Alias(a) = &resolved { + a.expr.as_ref() + } else { + &resolved + }; + if let Some(existing_alias) = existing_extractions.get(resolved_inner) { + // Same expression already extracted under a different alias — + // add the expression with the new alias so both names are + // available in the output. We can't reference the existing alias + // as a column within the same projection, so we duplicate the + // computation. + if existing_alias != alias { + proj_exprs.push(resolved); + } + } else { + proj_exprs.push(resolved); + } + } + + // Add any new pass-through columns that aren't already in the projection. + // We check against existing.input.schema() (the projection's source) rather + // than target_schema (the projection's output) because columns produced + // by alias expressions (e.g., CSE's __common_expr_N) exist in the output but + // not the input, and cannot be added as pass-through Column references. 
+ let existing_cols: IndexSet = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Column(c) = e { + Some(c.clone()) + } else { + None + } + }) + .collect(); + + let input_schema = existing.input.schema(); + for col in columns_needed { + let col_expr = Expr::Column(col.clone()); + let resolved = replace_cols_by_name(col_expr, &replace_map)?; + if let Expr::Column(resolved_col) = &resolved + && !existing_cols.contains(resolved_col) + && input_schema.has_column(resolved_col) + { + proj_exprs.push(Expr::Column(resolved_col.clone())); + } + // If resolved to non-column expr, it's already computed by existing projection + } + + Projection::try_new(proj_exprs, Arc::clone(&existing.input)) + } else { + // Build new projection with extracted expressions + all input columns + let mut proj_exprs = Vec::new(); + for (expr, alias) in extracted_exprs { + proj_exprs.push(expr.clone().alias(alias)); + } + for (qualifier, field) in target_schema.iter() { + proj_exprs.push(Expr::from((qualifier, field))); + } + Projection::try_new(proj_exprs, Arc::clone(target)) + } +} + +// ============================================================================= +// Pass 2: PushDownLeafProjections +// ============================================================================= + +/// Pushes extraction projections down through schema-preserving nodes towards +/// leaf nodes (pass 2 of 2, after [`ExtractLeafExpressions`]). +/// +/// Handles two types of projections: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns): +/// pushes through Filter/Sort/Limit, merges into existing projections, or routes +/// into multi-input node inputs (Join, SubqueryAlias, etc.) +/// - **Mixed projections** (user projections containing `MoveTowardsLeafNodes` +/// sub-expressions): splits into a recovery projection + extraction projection, +/// then pushes the extraction projection down. 
+/// +/// # Example: Pushing through a Filter +/// +/// After pass 1, the extraction projection sits directly below the filter: +/// ```text +/// Projection: id, user <-- recovery +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction +/// TableScan: t [id, user] +/// ``` +/// +/// Pass 2 pushes the extraction projection through the recovery and filter, +/// and a subsequent `OptimizeProjections` pass removes the (now-redundant) +/// recovery projection: +/// ```text +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction (pushed down) +/// TableScan: t [id, user] +/// ``` +#[derive(Default, Debug)] +pub struct PushDownLeafProjections {} + +impl PushDownLeafProjections { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownLeafProjections { + fn name(&self) -> &str { + "push_down_leaf_projections" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + match try_push_input(&plan, alias_generator)? { + Some(new_plan) => Ok(Transformed::yes(new_plan)), + None => Ok(Transformed::no(plan)), + } + } +} + +/// Attempts to push a projection's extractable expressions further down. +/// +/// Returns `Some(new_subtree)` if the projection was pushed down or merged, +/// `None` if there is nothing to push or the projection sits above a barrier. 
+fn try_push_input( + input: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let LogicalPlan::Projection(proj) = input else { + return Ok(None); + }; + split_and_push_projection(proj, alias_generator) +} + +/// Splits a projection into extractable pieces, pushes them towards leaf +/// nodes, and adds a recovery projection if needed. +/// +/// Handles both: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns) +/// - **Mixed projections** (containing `MoveTowardsLeafNodes` sub-expressions) +/// +/// Returns `Some(new_subtree)` if extractions were pushed down, +/// `None` if there is nothing to extract or push. +/// +/// # Example: Mixed Projection +/// +/// ```text +/// Input plan: +/// Projection: user['name'] IS NOT NULL AS has_name, id +/// Filter: ... +/// TableScan +/// +/// Phase 1 (Split): +/// extraction_pairs: [(user['name'], "__datafusion_extracted_1")] +/// recovery_exprs: [__datafusion_extracted_1 IS NOT NULL AS has_name, id] +/// +/// Phase 2 (Push): +/// Push extraction projection through Filter toward TableScan +/// +/// Phase 3 (Recovery): +/// Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, id <-- recovery +/// Filter: ... +/// Projection: user['name'] AS __datafusion_extracted_1, id <-- extraction (pushed) +/// TableScan +/// ``` +fn split_and_push_projection( + proj: &Projection, + alias_generator: &Arc, +) -> Result> { + // Fast pre-check: skip if there are no pre-existing extracted aliases + // and no new extractable expressions. 
+ let has_existing_extracted = proj.expr.iter().any(|e| { + matches!(e, Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX)) + }); + if !has_existing_extracted && !has_extractable_expr(&proj.expr) { + return Ok(None); + } + + let input = &proj.input; + let input_schema = input.schema(); + + // ── Phase 1: Split ────────────────────────────────────────────────── + // For each projection expression, collect extraction pairs and build + // recovery expressions. + // + // Pre-existing `__datafusion_extracted` aliases are inserted into the + // extractor's `IndexMap` with the **full** `Expr::Alias(…)` as the key, + // so the alias name participates in equality. This prevents collisions + // when CSE rewrites produce the same inner expression under different + // alias names (e.g. `__common_expr_4 AS __datafusion_extracted_1` and + // `__common_expr_4 AS __datafusion_extracted_3`). New extractions from + // `routing_extract` use bare (non-Alias) keys and get normal dedup. + // + // When building the final `extraction_pairs`, the Alias wrapper is + // stripped so consumers see the usual `(inner_expr, alias_name)` tuples. + + let mut extractors = vec![LeafExpressionExtractor::new( + input_schema.as_ref(), + alias_generator, + )]; + let input_column_sets = vec![schema_columns(input_schema.as_ref())]; + + let original_schema = proj.schema.as_ref(); + let mut recovery_exprs: Vec = Vec::with_capacity(proj.expr.len()); + let mut needs_recovery = false; + let mut has_new_extractions = false; + let mut proj_exprs_captured: usize = 0; + // Track standalone column expressions (Case B) to detect column refs + // from extracted aliases (Case A) that aren't also standalone expressions. 
+ let mut standalone_columns: IndexSet = IndexSet::new(); + + for (expr, (qualifier, field)) in proj.expr.iter().zip(original_schema.iter()) { + if let Expr::Alias(alias) = expr + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + // Insert the full Alias expression as the key so that + // distinct alias names don't collide in the IndexMap. + let alias_name = alias.name.clone(); + + for col_ref in alias.expr.column_refs() { + extractors[0].columns_needed.insert(col_ref.clone()); + } + + extractors[0] + .extracted + .insert(expr.clone(), alias_name.clone()); + recovery_exprs.push(Expr::Column(Column::new_unqualified(&alias_name))); + proj_exprs_captured += 1; + } else if let Expr::Column(col) = expr { + // Plain column pass-through — track it in the extractor + extractors[0].columns_needed.insert(col.clone()); + standalone_columns.insert(col.clone()); + recovery_exprs.push(expr.clone()); + proj_exprs_captured += 1; + } else { + // Everything else: run through routing_extract + let transformed = + routing_extract(expr.clone(), &mut extractors, &input_column_sets)?; + if transformed.transformed { + has_new_extractions = true; + } + let transformed_expr = transformed.data; + + // Build recovery expression, aliasing back to original name if needed + let original_name = field.name(); + let needs_alias = if let Expr::Column(col) = &transformed_expr { + col.name.as_str() != original_name + } else { + let expr_name = transformed_expr.schema_name().to_string(); + original_name != &expr_name + }; + let recovery_expr = if needs_alias { + needs_recovery = true; + transformed_expr + .clone() + .alias_qualified(qualifier.cloned(), original_name) + } else { + transformed_expr.clone() + }; + + // If the expression was transformed (i.e., has extracted sub-parts), + // it differs from what the pushed projection outputs → needs recovery. 
+ // Also, any non-column, non-__datafusion_extracted expression needs recovery + // because the pushed extraction projection won't output it directly. + if transformed.transformed || !matches!(expr, Expr::Column(_)) { + needs_recovery = true; + } + + recovery_exprs.push(recovery_expr); + } + } + + // Build extraction_pairs, stripping the Alias wrapper from pre-existing + // entries (they used the full Alias as the map key to avoid dedup). + let extractor = &extractors[0]; + let extraction_pairs: Vec<(Expr, String)> = extractor + .extracted + .iter() + .map(|(e, a)| match e { + Expr::Alias(alias) => (*alias.expr.clone(), a.clone()), + _ => (e.clone(), a.clone()), + }) + .collect(); + let columns_needed = &extractor.columns_needed; + + // If no extractions found, nothing to do + if extraction_pairs.is_empty() { + return Ok(None); + } + + // If columns_needed has entries that aren't standalone projection columns + // (i.e., they came from column refs inside extracted aliases), a merge + // into an inner projection will widen the schema with those extra columns, + // requiring a recovery projection to restore the original schema. + if columns_needed + .iter() + .any(|c| !standalone_columns.contains(c)) + { + needs_recovery = true; + } + + // ── Phase 2: Push down ────────────────────────────────────────────── + let proj_input = Arc::clone(&proj.input); + let pushed = push_extraction_pairs( + &extraction_pairs, + columns_needed, + proj, + &proj_input, + alias_generator, + proj_exprs_captured, + )?; + + // ── Phase 3: Recovery ─────────────────────────────────────────────── + // Determine the base plan: either the pushed result or an in-place extraction. + let base_plan = match pushed { + Some(plan) => plan, + None => { + if !has_new_extractions { + // Only pre-existing __datafusion_extracted aliases and columns, no new + // extractions from routing_extract. The original projection is + // already an extraction projection that couldn't be pushed + // further. 
Return None. + return Ok(None); + } + // Build extraction projection in-place (couldn't push down) + let input_arc = Arc::clone(input); + let extraction = build_extraction_projection_impl( + &extraction_pairs, + columns_needed, + &input_arc, + input_schema.as_ref(), + )?; + LogicalPlan::Projection(extraction) + } + }; + + // Wrap with recovery projection if the output schema changed + if needs_recovery { + let recovery = LogicalPlan::Projection(Projection::try_new( + recovery_exprs, + Arc::new(base_plan), + )?); + Ok(Some(recovery)) + } else { + Ok(Some(base_plan)) + } +} + +/// Returns true if the plan is a Projection where ALL expressions are either +/// `Alias(EXTRACTED_EXPR_PREFIX, ...)` or `Column`, with at least one extraction. +/// Such projections can safely be pushed further without re-extraction. +fn is_pure_extraction_projection(plan: &LogicalPlan) -> bool { + let LogicalPlan::Projection(proj) = plan else { + return false; + }; + let mut has_extraction = false; + for expr in &proj.expr { + match expr { + Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX) => { + has_extraction = true; + } + Expr::Column(_) => {} + _ => return false, + } + } + has_extraction +} + +/// Pushes extraction pairs down through the projection's input node, +/// dispatching to the appropriate handler based on the input node type. +fn push_extraction_pairs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + proj: &Projection, + proj_input: &Arc, + alias_generator: &Arc, + proj_exprs_captured: usize, +) -> Result> { + match proj_input.as_ref() { + // Merge into existing projection, then try to push the result further down. + // Only merge when every expression in the outer projection is fully + // captured as either an extraction pair (Case A: __datafusion_extracted + // alias) or a plain column (Case B). Uncaptured expressions (e.g. + // `col AS __common_expr_1` from CSE, or complex expressions with + // extracted sub-parts) would be lost during the merge. 
+ LogicalPlan::Projection(_) if proj_exprs_captured == proj.expr.len() => { + let target_schema = Arc::clone(proj_input.schema()); + let merged = build_extraction_projection_impl( + pairs, + columns_needed, + proj_input, + target_schema.as_ref(), + )?; + let merged_plan = LogicalPlan::Projection(merged); + + // After merging, try to push the result further down, but ONLY + // if the merged result is still a pure extraction projection + // (all __datafusion_extracted aliases + columns). If the merge inherited + // bare MoveTowardsLeafNodes expressions from the inner projection, + // pushing would re-extract them into new aliases and fail when + // the (None, true) fallback can't find the original aliases. + // This handles: Extraction → Recovery(cols) → Filter → ... → TableScan + // by pushing through the recovery projection AND the filter in one pass. + if is_pure_extraction_projection(&merged_plan) + && let Some(pushed) = try_push_input(&merged_plan, alias_generator)? + { + return Ok(Some(pushed)); + } + Ok(Some(merged_plan)) + } + // Generic: handles Filter/Sort/Limit (via recursion), + // SubqueryAlias (with qualifier remap in try_push_into_inputs), + // Join, and anything else. + // Safely bails out for nodes that don't pass through extracted + // columns (Aggregate, Window) via the output schema check. + _ => try_push_into_inputs( + pairs, + columns_needed, + proj_input.as_ref(), + alias_generator, + ), + } +} + +/// Routes extraction pairs and columns to the appropriate inputs. +/// +/// - **Union**: broadcasts to every input via [`remap_pairs_and_columns`]. +/// - **Other nodes**: routes each expression to the one input that owns +/// all of its column references (via [`find_owning_input`]). +/// +/// Returns `None` if any expression can't be routed or no input has pairs. 
+fn route_to_inputs( + pairs: &[(Expr, String)], + columns: &IndexSet, + node: &LogicalPlan, + input_column_sets: &[std::collections::HashSet], + input_schemas: &[Arc], +) -> Result>> { + let num_inputs = input_schemas.len(); + let mut per_input: Vec = (0..num_inputs) + .map(|_| ExtractionTarget { + pairs: vec![], + columns: IndexSet::new(), + }) + .collect(); + + if matches!(node, LogicalPlan::Union(_)) { + // Union output schema and each input schema have the same fields by + // index but may differ in qualifiers (e.g. output `s` vs input + // `simple_struct.s`). Remap pairs/columns to each input's space. + let union_schema = node.schema(); + for (idx, input_schema) in input_schemas.iter().enumerate() { + per_input[idx] = + remap_pairs_and_columns(pairs, columns, union_schema, input_schema)?; + } + } else { + for (expr, alias) in pairs { + match find_owning_input(expr, input_column_sets) { + Some(idx) => per_input[idx].pairs.push((expr.clone(), alias.clone())), + None => return Ok(None), // Cross-input expression — bail out + } + } + for col in columns { + let col_expr = Expr::Column(col.clone()); + match find_owning_input(&col_expr, input_column_sets) { + Some(idx) => { + per_input[idx].columns.insert(col.clone()); + } + None => return Ok(None), // Ambiguous column — bail out + } + } + } + + // Check at least one input has extractions to push + if per_input.iter().all(|t| t.pairs.is_empty()) { + return Ok(None); + } + + Ok(Some(per_input)) +} + +/// Pushes extraction expressions into a node's inputs by routing each +/// expression to the input that owns all of its column references. +/// +/// Works for any number of inputs (1, 2, …N). For single-input nodes, +/// all expressions trivially route to that input. For multi-input nodes +/// (Join, etc.), each expression is routed to the side that owns its columns. +/// +/// Returns `Some(new_node)` if all expressions could be routed AND the +/// rebuilt node's output schema contains all extracted aliases. 
+/// Returns `None` if any expression references columns from multiple inputs +/// or the node doesn't pass through the extracted columns. +/// +/// # Example: Join with expressions from both sides +/// +/// ```text +/// Extraction projection above a Join: +/// Projection: left.user['name'] AS __datafusion_extracted_1, right.order['total'] AS __datafusion_extracted_2, ... +/// Join: left.id = right.user_id +/// TableScan: left [id, user] +/// TableScan: right [user_id, order] +/// +/// After routing each expression to its owning input: +/// Join: left.id = right.user_id +/// Projection: user['name'] AS __datafusion_extracted_1, id, user <-- left-side extraction +/// TableScan: left [id, user] +/// Projection: order['total'] AS __datafusion_extracted_2, user_id, order <-- right-side extraction +/// TableScan: right [user_id, order] +/// ``` +fn try_push_into_inputs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + node: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let inputs = node.inputs(); + if inputs.is_empty() { + return Ok(None); + } + + // SubqueryAlias remaps qualifiers between input and output. + // Rewrite pairs/columns from alias-space to input-space before routing. + let remapped = if let LogicalPlan::SubqueryAlias(sa) = node { + remap_pairs_and_columns(pairs, columns_needed, &sa.schema, sa.input.schema())? + } else { + ExtractionTarget { + pairs: pairs.to_vec(), + columns: columns_needed.clone(), + } + }; + let pairs = &remapped.pairs[..]; + let columns_needed = &remapped.columns; + + // Build per-input schemas and column sets for routing + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + let input_column_sets: Vec> = + input_schemas.iter().map(|s| schema_columns(s)).collect(); + + // Route pairs and columns to the appropriate inputs + let per_input = match route_to_inputs( + pairs, + columns_needed, + node, + &input_column_sets, + &input_schemas, + )? 
{ + Some(routed) => routed, + None => return Ok(None), + }; + + let num_inputs = inputs.len(); + + // Build per-input extraction projections and push them as far as possible + // immediately. This is critical because map_children preserves cached schemas, + // so if the TopDown pass later pushes a child further (changing its output + // schema), the parent node's schema becomes stale. + let mut new_inputs: Vec = Vec::with_capacity(num_inputs); + for (idx, input) in inputs.into_iter().enumerate() { + if per_input[idx].pairs.is_empty() { + new_inputs.push(input.clone()); + } else { + let input_arc = Arc::new(input.clone()); + let target_schema = Arc::clone(input.schema()); + let proj = build_extraction_projection_impl( + &per_input[idx].pairs, + &per_input[idx].columns, + &input_arc, + target_schema.as_ref(), + )?; + // Verify all requested aliases appear in the projection's output. + // A merge may deduplicate if the same expression already exists + // under a different alias, leaving the requested alias missing. + let proj_schema = proj.schema.as_ref(); + for (_expr, alias) in &per_input[idx].pairs { + if !proj_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + let proj_plan = LogicalPlan::Projection(proj); + // Try to push the extraction projection further down within + // this input (e.g., through Filter → existing extraction projection). + // This ensures the input's output schema is stable and won't change + // when the TopDown pass later visits children. + match try_push_input(&proj_plan, alias_generator)? { + Some(pushed) => new_inputs.push(pushed), + None => new_inputs.push(proj_plan), + } + } + } + + // Rebuild the node with new inputs + let new_node = node.with_new_exprs(node.expressions(), new_inputs)?; + + // Safety check: verify all extracted aliases appear in the rebuilt + // node's output schema. Nodes like Aggregate define their own output + // and won't pass through extracted columns — bail out for those. 
+ let output_schema = new_node.schema(); + for (_expr, alias) in pairs { + if !output_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + + Ok(Some(new_node)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::optimize_projections::OptimizeProjections; + use crate::test::udfs::PlacementTestUDF; + use crate::test::*; + use crate::{Optimizer, OptimizerContext}; + use datafusion_common::Result; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{Expr, ExpressionPlacement}; + use datafusion_expr::{ + ScalarUDF, col, lit, logical_plan::builder::LogicalPlanBuilder, + }; + + fn leaf_udf(expr: Expr, name: &str) -> Expr { + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + )), + vec![expr, lit(name)], + )) + } + + // ========================================================================= + // Combined optimization stage formatter + // ========================================================================= + + /// Runs all 4 optimization stages and returns a single formatted string. + /// Stages that produce the same plan as the previous stage show + /// "(same as )" to reduce noise. + /// + /// Stages: + /// 1. **Original** - OptimizeProjections only (baseline) + /// 2. **After Extraction** - + ExtractLeafExpressions + /// 3. **After Pushdown** - + PushDownLeafProjections + /// 4. 
**Optimized** - + final OptimizeProjections + fn format_optimization_stages(plan: &LogicalPlan) -> Result { + let run = |rules: Vec>| -> Result { + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = Optimizer::with_rules(rules); + let optimized = optimizer.optimize(plan.clone(), &ctx, |_, _| {})?; + Ok(format!("{optimized}")) + }; + + let original = run(vec![Arc::new(OptimizeProjections::new())])?; + + let after_extract = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + ])?; + + let after_pushdown = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + ])?; + + let optimized = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + Arc::new(OptimizeProjections::new()), + ])?; + + let mut out = format!("## Original Plan\n{original}"); + + out.push_str("\n\n## After Extraction\n"); + if after_extract == original { + out.push_str("(same as original)"); + } else { + out.push_str(&after_extract); + } + + out.push_str("\n\n## After Pushdown\n"); + if after_pushdown == after_extract { + out.push_str("(same as after extraction)"); + } else { + out.push_str(&after_pushdown); + } + + out.push_str("\n\n## Optimized\n"); + if optimized == after_pushdown { + out.push_str("(same as after pushdown)"); + } else { + out.push_str(&optimized); + } + + Ok(out) + } + + /// Assert all optimization stages for a plan in a single insta snapshot. + macro_rules! assert_stages { + ($plan:expr, @ $expected:literal $(,)?) 
=> {{ + let result = format_optimization_stages(&$plan)?; + insta::assert_snapshot!(result, @ $expected); + Ok::<(), datafusion_common::DataFusionError>(()) + }}; + } + + #[test] + fn test_extract_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + "#) + } + + #[test] + fn test_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").eq(lit(1)))? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Filter: test.a = Int32(1) + TableScan: test projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_extract_from_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_projection_with_subexpression() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? 
+ .build()?; + + assert_stages!(plan, @" + ## Original Plan + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_filter_with_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_access = leaf_udf(col("user"), "name"); + // Filter with the same expression used twice + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + field_access + .clone() + .is_not_null() + .and(field_access.is_null()), + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL AND leaf_udf(test.user, Utf8("name")) IS NULL + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL AND __datafusion_extracted_1 IS NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_already_leaf_expression_in_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "name").eq(lit("test")))? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) = Utf8("test") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("test") + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_extract_from_aggregate_group_by() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![leaf_udf(col("user"), "status")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("status"))]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_aggregate_args() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value"))], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value")))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, COUNT(__datafusion_extracted_1) AS COUNT(leaf_udf(test.user,Utf8("value"))) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1)]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_with_filter_combined() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(test.user, Utf8("name")) + Projection: test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS 
__datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_preserves_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name").alias("username")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS username + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + "#) + } + + /// Test: Projection with different field than Filter + /// SELECT id, s['label'] FROM t WHERE s['value'] > 150 + /// Both s['label'] and s['value'] should be in a single extraction projection. + #[test] + fn test_projection_different_field_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "value").gt(lit(150)))? + .project(vec![col("user"), leaf_udf(col("user"), "label")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Filter: leaf_udf(test.user, Utf8("value")) > Int32(150) + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Projection: test.user + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: test.user, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("label")) + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("label")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field = leaf_udf(col("user"), "name"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![field.clone(), field.clone().alias("name2")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_1 AS name2 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // Additional tests for code coverage + // ========================================================================= + + /// Extractions push through Sort nodes to reach the TableScan. + #[test] + fn test_extract_through_sort() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("user").sort(true, true)])? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Sort: test.user ASC NULLS FIRST + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Sort: test.user ASC NULLS FIRST + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extractions push through Limit nodes to reach the TableScan. + #[test] + fn test_extract_through_limit() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .limit(0, Some(10))? 
+ .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Limit: skip=0, fetch=10 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Aliased aggregate functions like count(...).alias("cnt") are handled. + #[test] + fn test_extract_from_aliased_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value")).alias("cnt")], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value"))) AS cnt]] + TableScan: test projection=[user] + + ## After Extraction + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1) AS cnt]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Aggregates with no MoveTowardsLeafNodes expressions return unchanged. 
+ #[test] + fn test_aggregate_no_extraction() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![count(col("b"))])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b)]] + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections containing extracted expression aliases are skipped (already extracted). + #[test] + fn test_skip_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name").alias("__datafusion_extracted_manual"), + col("user"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_manual, test.user + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Multiple extractions merge into a single extracted expression projection. + #[test] + fn test_merge_into_existing_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("user"), "name").is_not_null())? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + Projection: test.id, test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: test.id, test.user, __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + "#) + } + + /// Extractions push through passthrough projections (columns only). + #[test] + fn test_extract_through_passthrough_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user")])? + .project(vec![leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + /// Projections with aliased columns (nothing to extract) return unchanged. + #[test] + fn test_projection_early_return_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("x"), col("b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a AS x, test.b + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections with arithmetic expressions but no MoveTowardsLeafNodes return unchanged. + #[test] + fn test_projection_with_arithmetic_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![(col("a") + col("b")).alias("sum")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a + test.b AS sum + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Aggregate extractions merge into existing extracted projection created by Filter. 
+ #[test] + fn test_aggregate_merge_into_extracted_projection() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .aggregate(vec![leaf_udf(col("user"), "name")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("name"))]], aggr=[[COUNT(Int32(1))]] + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + Projection: test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// 
Projection containing a MoveTowardsLeafNodes sub-expression above an + /// Aggregate. Aggregate blocks pushdown, so the (None, true) recovery + /// fallback path fires: in-place extraction + recovery projection. + #[test] + fn test_projection_with_leaf_expr_above_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("user")], vec![count(lit(1))])? + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + col("COUNT(Int32(1))"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, COUNT(Int32(1)) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + "#) + } + + /// Merging adds new pass-through columns not in the existing extracted projection. + #[test] + fn test_merge_with_new_columns() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("a"), "x").eq(lit(1)))? + .filter(leaf_udf(col("b"), "y").eq(lit(2)))? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.b, Utf8("y")) = Int32(2) + Filter: leaf_udf(test.a, Utf8("x")) = Int32(1) + TableScan: test projection=[a, b, c] + + ## After Extraction + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Projection: leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1, test.a, test.b, test.c + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c + TableScan: test projection=[a, b, c] + + ## After Pushdown + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1 + TableScan: test projection=[a, b, c] + + ## Optimized + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Projection: test.a, test.b, test.c, __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1 + TableScan: test projection=[a, b, c] + "#) + } + + // ========================================================================= + // Join extraction tests + // ========================================================================= + + /// Create a second table scan with struct field for join tests + fn test_table_scan_with_struct_named(name: &str) -> Result<LogicalPlan> { + use arrow::datatypes::Schema; + let schema = Schema::new(test_table_scan_with_struct_fields()); + datafusion_expr::logical_plan::table_scan(Some(name), &schema, None)?.build() + } + + /// Extraction from equijoin keys (`on` expressions). 
+ #[test] + fn test_extract_from_join_on() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from non-equi join filter. + #[test] + fn test_extract_from_join_filter() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from both left and right sides of a join. + #[test] + fn test_extract_from_join_both_sides() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + leaf_udf(col("right.user"), "role").eq(lit("admin")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") AND leaf_udf(right.user, Utf8("role")) = Utf8("admin") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") AND __datafusion_extracted_2 = Utf8("admin") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Join with no MoveTowardsLeafNodes expressions returns unchanged. + #[test] + fn test_extract_from_join_no_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan()?; + let right = test_table_scan_with_name("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Inner Join: test.a = right.a + TableScan: test projection=[a, b, c] + TableScan: right projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Join followed by filter with extraction. 
+ #[test] + fn test_extract_from_filter_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .filter(leaf_udf(col("test.user"), "status").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, right.id, right.user + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, test.user, right.id, 
right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.id, test.user, __datafusion_extracted_1, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection (get_field in SELECT) above a Join pushes into + /// the correct input side. + #[test] + fn test_extract_projection_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["id"], vec!["id"]), None)? + .project(vec![ + leaf_udf(col("test.user"), "status"), + leaf_udf(col("right.user"), "role"), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("status")), leaf_udf(right.user, Utf8("role")) + Inner Join: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id + TableScan: right projection=[id, user] + "#) + } + + /// Join where both sides have same-named columns: a qualified reference + /// to the right side must be routed to the right input, not the left. + #[test] + fn test_extract_from_join_qualified_right_side() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + // Filter references right.user explicitly — must route to right side + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").eq(lit("active")), + ], + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id AND __datafusion_extracted_1 = Utf8("active") + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// When both inputs contain the same unqualified column, an unqualified + /// column reference is ambiguous and `find_owning_input` must return + /// `None` rather than always returning 0 (the left side). + #[test] + fn test_find_owning_input_ambiguous_unqualified_column() { + use std::collections::HashSet; + + // Simulate schema_columns output for two sides of a join where both + // have a "user" column — each set contains the qualified and + // unqualified form. 
+ let left_cols: HashSet<Column> = [ + Column::new(Some("test"), "user"), + Column::new_unqualified("user"), + ] + .into_iter() + .collect(); + + let right_cols: HashSet<Column> = [ + Column::new(Some("right"), "user"), + Column::new_unqualified("user"), + ] + .into_iter() + .collect(); + + let input_column_sets = vec![left_cols, right_cols]; + + // Unqualified "user" matches both sets — must return None (ambiguous) + let unqualified = Expr::Column(Column::new_unqualified("user")); + assert_eq!(find_owning_input(&unqualified, &input_column_sets), None); + + // Qualified "right.user" matches only the right set — must return Some(1) + let qualified_right = Expr::Column(Column::new(Some("right"), "user")); + assert_eq!( + find_owning_input(&qualified_right, &input_column_sets), + Some(1) + ); + + // Qualified "test.user" matches only the left set — must return Some(0) + let qualified_left = Expr::Column(Column::new(Some("test"), "user")); + assert_eq!( + find_owning_input(&qualified_left, &input_column_sets), + Some(0) + ); + } + + /// Two leaf_udf expressions from different sides of a Join in a Filter. + /// Each is routed to its respective input side independently. + #[test] + fn test_extract_from_join_cross_input_expression() -> Result<()> { + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + datafusion_expr::JoinType::Inner, + vec![col("test.id").eq(col("right.id"))], + )? + .filter( + leaf_udf(col("test.user"), "status") + .eq(leaf_udf(col("right.user"), "status")), + )? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = leaf_udf(right.user, Utf8("status")) + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Inner Join: Filter: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + (same as after pushdown) + "#) + } + + // ========================================================================= + // Column-rename through intermediate node tests + // ========================================================================= + + /// Projection with leaf expr above Filter above renaming Projection. + #[test] + fn test_extract_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Same as above but with a partial extraction (leaf + arithmetic). + #[test] + fn test_extract_partial_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a").is_not_null()])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Tests merge_into_extracted_projection path through a renaming projection. 
+ #[test] + fn test_extract_from_filter_above_renaming_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(leaf_udf(col("x"), "a").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(x, Utf8("a")) = Utf8("active") + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // SubqueryAlias extraction tests + // ========================================================================= + + /// Extraction projection pushes through SubqueryAlias. + #[test] + fn test_extract_through_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![leaf_udf(col("sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Extraction projection pushes through SubqueryAlias + Filter. + #[test] + fn test_extract_through_subquery_alias_with_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .filter(leaf_udf(col("sub.user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + Filter: leaf_udf(sub.user, Utf8("status")) = Utf8("active") + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(sub.user, Utf8("name")) + Projection: sub.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(sub.user, Utf8("status")) AS __datafusion_extracted_1, sub.user + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + /// Two layers of SubqueryAlias: extraction pushes through both. + #[test] + fn test_extract_through_nested_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("inner_sub")? + .alias("outer_sub")? + .project(vec![leaf_udf(col("outer_sub.user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(outer_sub.user, Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Plain columns through SubqueryAlias -- no extraction needed. + #[test] + fn test_subquery_alias_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![col("sub.a"), col("sub.b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + SubqueryAlias: sub + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Two UDFs with the same `name()` but different concrete types should NOT be + /// deduplicated -- they are semantically different expressions that happen to + /// collide on `schema_name()`. 
+ #[test] + fn test_different_udfs_same_schema_name_not_deduplicated() -> Result<()> { + let udf_a = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(1), + )); + let udf_b = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(2), + )); + + let expr_a = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_a, + vec![col("user"), lit("field")], + )); + let expr_b = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_b, + vec![col("user"), lit("field")], + )); + + // Verify preconditions: same schema_name but different Expr + assert_eq!( + expr_a.schema_name().to_string(), + expr_b.schema_name().to_string(), + "Both expressions should have the same schema_name" + ); + assert_ne!( + expr_a, expr_b, + "Expressions should NOT be equal (different UDF instances)" + ); + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(expr_a.clone().eq(lit("a")).and(expr_b.clone().eq(lit("b"))))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("field")) = Utf8("a") AND leaf_udf(test.user, Utf8("field")) = Utf8("b") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + "#) + } + + // ========================================================================= + // Filter pushdown interaction tests + // ========================================================================= + + /// Extraction pushdown through a filter that already had its own + /// `leaf_udf` extracted. + #[test] + fn test_extraction_pushdown_through_filter_with_extracted_predicate() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![col("id"), leaf_udf(col("user"), "name")])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Same expression in filter predicate and projection output. + #[test] + fn test_extraction_pushdown_same_expr_in_filter_and_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_expr = leaf_udf(col("user"), "status"); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(field_expr.clone().gt(lit(5)))? + .project(vec![col("id"), field_expr])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Left join with a `leaf_udf` filter on the right side AND + /// the projection also selects `leaf_udf` from the right side. + #[test] + fn test_left_join_with_filter_and_projection_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Left, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").gt(lit(5)), + ], + )? + .project(vec![ + col("test.id"), + leaf_udf(col("test.user"), "name"), + leaf_udf(col("right.user"), "status"), + ])? 
+ .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Left Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Projection: test.id, test.user, right.id, right.user + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection pushed through a filter whose 
predicate + /// references a different extracted expression. + #[test] + fn test_pure_extraction_proj_push_through_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").gt(lit(5)))? + .project(vec![ + col("id"), + leaf_udf(col("user"), "name"), + leaf_udf(col("user"), "status"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + "#) + } + + /// When an extraction projection's __extracted 
alias references a column + /// (e.g. `user`) that is NOT a standalone expression in the projection, + /// the merge into the inner projection should still succeed. + #[test] + fn test_merge_extraction_into_projection_with_column_ref_inflation() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + + // Inner projection (simulates a trimmed projection) + let inner = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user"), col("id")])? + .build()?; + + // Outer projection: __extracted alias + id (but NOT user as standalone). + // The alias references `user` internally, inflating columns_needed. + let plan = LogicalPlanBuilder::from(inner) + .project(vec![ + leaf_udf(col("user"), "status") + .alias(format!("{EXTRACTED_EXPR_PREFIX}_1")), + col("id"), + ])? + .build()?; + + // Run only PushDownLeafProjections + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = + Optimizer::with_rules(vec![Arc::new(PushDownLeafProjections::new())]); + let result = optimizer.optimize(plan, &ctx, |_, _| {})?; + + // With the fix: merge succeeds → extraction merged into inner projection. + // Without the fix: merge rejected → two separate projections remain. + insta::assert_snapshot!(format!("{result}"), @r#" + Projection: __datafusion_extracted_1, test.id + Projection: test.user, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test + "#); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index a1a59cb348876..e610091824092 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! 
# DataFusion Optimizer @@ -58,6 +57,7 @@ pub mod eliminate_nested_union { } pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; +pub mod extract_leaf_expressions; pub mod filter_null_join_keys; pub mod optimize_projections; pub mod optimize_unions; @@ -66,6 +66,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 548eadffa242e..93df300bb50b4 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -268,15 +268,10 @@ fn optimize_projections( Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), None => indices.into_inner(), }; - return TableScan::try_new( - table_name, - source, - Some(projection), - filters, - fetch, - ) - .map(LogicalPlan::TableScan) - .map(Transformed::yes); + let new_scan = + TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; + + return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); } // Other node types are handled below _ => {} @@ -530,15 +525,14 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 - && !is_expr_trivial( - &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], - ) + && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()] + .placement() + .should_push_to_leaves() }) { // no change return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); @@ -565,7 +559,19 @@ fn merge_consecutive_projections(proj: Projection) -> Result rewrite_expr(*expr, &prev_projection).map(|result| { result.update_data(|expr| { - Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata)) + // After 
substitution, the inner expression may now have the + // same schema_name as the alias (e.g. when an extraction + // alias like `__extracted_1 AS f(x)` is resolved back to + // `f(x)`). Wrapping in a redundant self-alias causes a + // cosmetic `f(x) AS f(x)` due to Display vs schema_name + // formatting differences. Drop the alias when it matches. + if metadata.is_none() && expr.schema_name().to_string() == name { + expr + } else { + Expr::Alias( + Alias::new(expr, relation, name).with_metadata(metadata), + ) + } }) }), e => rewrite_expr(e, &prev_projection), @@ -591,11 +597,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) -} - /// Rewrites a projection expression using the projection before it (i.e. its input) /// This is a subroutine to the `merge_consecutive_projections` function. /// diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 8740ab072a1f5..118ddef49b7e7 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -43,6 +43,7 @@ use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; +use crate::extract_leaf_expressions::{ExtractLeafExpressions, PushDownLeafProjections}; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::optimize_projections::OptimizeProjections; use crate::optimize_unions::OptimizeUnions; @@ -51,6 +52,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use 
crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -235,6 +237,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(RewriteSetComparison::new()), Arc::new(OptimizeUnions::new()), Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -258,6 +261,8 @@ impl Optimizer { // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 755ffdbafc869..15bb5db07d2c2 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -45,6 +45,7 @@ use crate::optimizer::ApplyOrder; use crate::simplify_expressions::simplify_predicates; use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::ExpressionPlacement; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -263,6 +264,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. 
} | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) | Expr::Unnest(_) => { @@ -454,11 +456,11 @@ fn push_down_all_join( } } - // For infer predicates, if they can not push through join, just drop them + // Push predicates inferred from the join expression for predicate in inferred_join_predicates { - if left_preserved && checker.is_left_only(&predicate) { + if checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved && checker.is_right_only(&predicate) { + } else if checker.is_right_only(&predicate) { right_push.push(predicate); } } @@ -1294,10 +1296,13 @@ fn rewrite_projection( predicates: Vec, mut projection: Projection, ) -> Result<(Transformed, Option)> { - // A projection is filter-commutable if it do not contain volatile predicates or contain volatile - // predicates that are not used in the filter. However, we should re-writes all predicate expressions. - // collect projection. - let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection + // Partition projection expressions into non-pushable vs pushable. + // Non-pushable expressions are volatile (must not be duplicated) or + // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining + // into a filter causes optimizer instability — ExtractLeafExpressions will + // undo the push-down, creating an infinite loop that runs until the + // iteration limit is hit). 
+ let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection .schema .iter() .zip(projection.expr.iter()) @@ -1307,12 +1312,15 @@ fn rewrite_projection( (qualified_name(qualifier, field.name()), expr) }) - .partition(|(_, value)| value.is_volatile()); + .partition(|(_, value)| { + value.is_volatile() + || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes + }); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; for expr in predicates { - if contain(&expr, &volatile_map) { + if contain(&expr, &non_pushable_map) { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -1324,7 +1332,7 @@ fn rewrite_projection( // re-write all filters based on this projection // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(expr, &non_volatile_map)?, + replace_cols_by_name(expr, &pushable_map)?, std::mem::take(&mut projection.input), )?); @@ -1335,7 +1343,10 @@ fn rewrite_projection( conjunction(keep_predicates), )) } - None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), + None => Ok(( + Transformed::no(LogicalPlan::Projection(projection)), + conjunction(keep_predicates), + )), } } @@ -1445,6 +1456,7 @@ mod tests { use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::Optimizer; use crate::simplify_expressions::SimplifyExpressions; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; use datafusion_expr::test::function_stub::sum; use insta::assert_snapshot; @@ -2331,7 +2343,7 @@ mod tests { plan, @r" Projection: test.a, test1.d - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.d, test1.e, test1.f @@ -2361,7 +2373,7 @@ mod tests { plan, @r" Projection: test.a, test1.a - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = 
Int32(1)] Projection: test1.a, test1.b, test1.c @@ -2720,8 +2732,7 @@ mod tests { ) } - /// post-left-join predicate on a column common to both sides is only pushed to the left side - /// i.e. - not duplicated to the right side + /// post-left-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_left_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2749,20 +2760,19 @@ mod tests { TableScan: test2 ", ); - // filter sent to left side of the join, not the right + // filter sent to left side of the join and to the right assert_optimized_plan_equal!( plan, @r" Left Join: Using test.a = test2.a TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a - TableScan: test2 + TableScan: test2, full_filters=[test2.a <= Int64(1)] " ) } - /// post-right-join predicate on a column common to both sides is only pushed to the right side - /// i.e. - not duplicated to the left side. + /// post-right-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_right_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2790,12 +2800,12 @@ mod tests { TableScan: test2 ", ); - // filter sent to right side of join, not duplicated to the left + // filter sent to right side of join, sent to the left as well assert_optimized_plan_equal!( plan, @r" Right Join: Using test.a = test2.a - TableScan: test + TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a TableScan: test2, full_filters=[test2.a <= Int64(1)] " @@ -2977,7 +2987,7 @@ mod tests { Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c - TableScan: test2, full_filters=[test2.c > UInt32(4)] + TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)] " ) } @@ -4222,4 +4232,68 @@ mod tests { " ) } + + /// Test that filters are NOT pushed through MoveTowardsLeafNodes projections. 
+ /// These are cheap expressions (like get_field) where re-inlining into a filter + /// has no benefit and causes optimizer instability — ExtractLeafExpressions will + /// undo the push-down, creating an infinite loop that runs until the iteration + /// limit is hit. + #[test] + fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with a MoveTowardsLeafNodes expression + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Put a filter on the MoveTowardsLeafNodes column + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)))? + .build()?; + + // Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test + " + ) + } + + /// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept. + #[test] + fn filter_mixed_predicates_partial_push() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with both MoveTowardsLeafNodes and Column expressions + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column) + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))? 
+ .build()?; + + // val > 150 should be kept above, b > 5 should be pushed through + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test, full_filters=[test.b > Int64(5)] + " + ) + } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 7b302adf22acc..755e192e340d9 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -1044,7 +1044,7 @@ mod test { plan, @r" Limit: skip=0, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=1000 TableScan: test, fetch=1000 Limit: skip=0, fetch=1000 @@ -1067,7 +1067,7 @@ mod test { plan, @r" Limit: skip=1000, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=2000 TableScan: test, fetch=2000 Limit: skip=0, fetch=2000 diff --git a/datafusion/optimizer/src/rewrite_set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs new file mode 100644 index 0000000000000..c8c35b518743a --- /dev/null +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Optimizer rule rewriting `SetComparison` subqueries (e.g. `= ANY`, +//! `> ALL`) into boolean expressions built from `EXISTS` subqueries +//! that capture SQL three-valued logic. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err}; +use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; +use datafusion_expr::logical_plan::Subquery; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::{Expr, LogicalPlan, lit}; +use std::sync::Arc; + +use datafusion_expr::utils::merge_schema; + +/// Rewrite `SetComparison` expressions to scalar subqueries that return the +/// correct boolean value (including SQL NULL semantics). After this rule +/// runs, later rules such as `ScalarSubqueryToJoin` can decorrelate and +/// remove the remaining subquery. +#[derive(Debug, Default)] +pub struct RewriteSetComparison; + +impl RewriteSetComparison { + /// Create a new `RewriteSetComparison` optimizer rule. 
+ pub fn new() -> Self { + Self + } + + fn rewrite_plan(&self, plan: LogicalPlan) -> Result> { + let schema = merge_schema(&plan.inputs()); + plan.map_expressions(|expr| { + expr.transform_up(|expr| rewrite_set_comparison(expr, &schema)) + }) + } +} + +impl OptimizerRule for RewriteSetComparison { + fn name(&self) -> &str { + "rewrite_set_comparison" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) + } +} + +fn rewrite_set_comparison( + expr: Expr, + outer_schema: &DFSchema, +) -> Result> { + match expr { + Expr::SetComparison(set_comparison) => { + let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?; + Ok(Transformed::yes(rewritten)) + } + _ => Ok(Transformed::no(expr)), + } +} + +fn build_set_comparison_subquery( + set_comparison: SetComparison, + outer_schema: &DFSchema, +) -> Result { + let SetComparison { + expr, + subquery, + op, + quantifier, + } = set_comparison; + + let left_expr = to_outer_reference(*expr, outer_schema)?; + let subquery_schema = subquery.subquery.schema(); + if subquery_schema.fields().is_empty() { + return plan_err!("single expression required."); + } + // avoid `head_output_expr` for aggr/window plan, it will gives group-by expr if exists + let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0))); + + let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( + Box::new(left_expr), + op, + Box::new(right_expr), + )); + + let true_exists = + exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?; + let null_exists = + exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?; + + let result_expr = match quantifier { + SetQuantifier::Any => Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(true_exists), Box::new(lit(true))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], 
+ else_expr: Some(Box::new(lit(false))), + }), + SetQuantifier::All => { + let false_exists = + exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?; + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(false_exists), Box::new(lit(false))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(true))), + }) + } + }; + + Ok(result_expr) +} + +fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result { + let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone()) + .filter(filter)? + .build()?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: subquery.spans.clone(), + }, + negated: false, + })) +} + +fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { + expr.transform_up(|expr| match expr { + Expr::Column(col) => { + let field = outer_schema.field_from_column(&col)?; + Ok(Transformed::yes(Expr::OuterReferenceColumn( + Arc::clone(field), + col, + ))) + } + Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + _ => Ok(Transformed::no(expr)), + }) + .map(|t| t.data) +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55bff5849c5cb..c6644e008645a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -18,7 +18,7 @@ //! 
Expression simplification API use arrow::{ - array::{AsArray, new_null_array}, + array::{Array, AsArray, new_null_array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; @@ -28,6 +28,7 @@ use std::ops::Not; use std::sync::Arc; use datafusion_common::config::ConfigOptions; +use datafusion_common::nested_struct::has_one_of_more_common_fields; use datafusion_common::{ DFSchema, DataFusionError, Result, ScalarValue, exec_datafusion_err, internal_err, }; @@ -38,8 +39,8 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_expr::{ - BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, and, - binary::BinaryTypeCoercer, lit, or, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; @@ -51,7 +52,6 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::SimplifyContext; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ @@ -59,6 +59,10 @@ use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, unwrap_cast_in_comparison_for_binary, }; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + simplify_expressions::udf_preimage::rewrite_with_preimage, +}; use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; @@ -633,6 +637,7 @@ impl ConstEvaluator { | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. 
} | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } | Expr::GroupingSet(_) @@ -641,6 +646,35 @@ impl ConstEvaluator { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::Cast(Cast { expr, data_type }) + | Expr::TryCast(TryCast { expr, data_type }) => { + if let ( + Ok(DataType::Struct(source_fields)), + DataType::Struct(target_fields), + ) = (expr.get_type(&DFSchema::empty()), data_type) + { + // Don't const-fold struct casts with different field counts + if source_fields.len() != target_fields.len() { + return false; + } + + // Skip const-folding when there is no field name overlap + if !has_one_of_more_common_fields(&source_fields, target_fields) { + return false; + } + + // Don't const-fold struct casts with empty (0-row) literals + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + if let Expr::Literal(ScalarValue::Struct(struct_array), _) = + expr.as_ref() + && struct_array.len() == 0 + { + return false; + } + } + true + } Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) @@ -659,8 +693,6 @@ impl ConstEvaluator { | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Case(_) - | Expr::Cast { .. } - | Expr::TryCast { .. } | Expr::InList { .. 
} => true, } } @@ -1045,6 +1077,22 @@ impl TreeNodeRewriter for Simplifier<'_> { ); } } + // A = L1 AND A != L2 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&left, &right) => { + Transformed::yes(*left) + } + // A != L2 AND A = L1 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&right, &left) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -1952,12 +2000,132 @@ impl TreeNodeRewriter for Simplifier<'_> { })) } + // ======================================= + // preimage_in_comparison + // ======================================= + // + // For case: + // date_part('YEAR', expr) op literal + // + // For details see datafusion_expr::ScalarUDFImpl::preimage + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + use datafusion_expr::Operator::*; + let is_preimage_op = matches!( + op, + Eq | NotEq + | Lt + | LtEq + | Gt + | GtEq + | IsDistinctFrom + | IsNotDistinctFrom + ); + if !is_preimage_op || is_null(&right) { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + } + + if let PreimageResult::Range { interval, expr } = + get_preimage(left.as_ref(), right.as_ref(), info)? + { + rewrite_with_preimage(*interval, op, expr)? + } else if let Some(swapped) = op.swap() { + if let PreimageResult::Range { interval, expr } = + get_preimage(right.as_ref(), left.as_ref(), info)? + { + rewrite_with_preimage(*interval, swapped, expr)? + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } + // For case: + // date_part('YEAR', expr) IN (literal1, literal2, ...) 
+ Expr::InList(InList { + expr, + list, + negated, + }) => { + if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + } + + let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) = + if negated { (NotEq, and) } else { (Eq, or) }; + + let mut rewritten: Option = None; + for item in &list { + let PreimageResult::Range { interval, expr } = + get_preimage(expr.as_ref(), item, info)? + else { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + }; + + let range_expr = rewrite_with_preimage(*interval, op, expr)?.data; + rewritten = Some(match rewritten { + None => range_expr, + Some(acc) => combiner(acc, range_expr), + }); + } + + if let Some(rewritten) = rewritten { + Transformed::yes(rewritten) + } else { + Transformed::no(Expr::InList(InList { + expr, + list, + negated, + })) + } + } + // no additional rewrites possible expr => Transformed::no(expr), }) } } +fn get_preimage( + left_expr: &Expr, + right_expr: &Expr, + info: &SimplifyContext, +) -> Result { + let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else { + return Ok(PreimageResult::None); + }; + if !is_literal_or_literal_cast(right_expr) { + return Ok(PreimageResult::None); + } + if func.signature().volatility != Volatility::Immutable { + return Ok(PreimageResult::None); + } + func.preimage(args, right_expr, info) +} + +fn is_literal_or_literal_cast(expr: &Expr) -> bool { + match expr { + Expr::Literal(_, _) => true, + Expr::Cast(Cast { expr, .. }) => matches!(expr.as_ref(), Expr::Literal(_, _)), + Expr::TryCast(TryCast { expr, .. 
}) => { + matches!(expr.as_ref(), Expr::Literal(_, _)) + } + _ => false, + } +} + fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), @@ -2150,7 +2318,10 @@ mod tests { use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; - use arrow::datatypes::FieldRef; + use arrow::{ + array::{Int32Array, StructArray}, + datatypes::{FieldRef, Fields}, + }; use datafusion_common::{DFSchemaRef, ToDFSchema, assert_contains}; use datafusion_expr::{ expr::WindowFunction, @@ -2398,6 +2569,27 @@ mod tests { assert_eq!(simplify(expr_b), expected); } + #[test] + fn test_simplify_eq_and_neq_with_different_literals() { + // A = 1 AND A != 0 --> A = 1 (when 1 != 0) + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(0))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // A != 0 AND A = 1 --> A = 1 (when 1 != 0) + let expr = col("c2").not_eq(lit(0)).and(col("c2").eq(lit(1))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // Should NOT simplify when literals are the same (A = 1 AND A != 1) + // This is a contradiction but handled by other rules + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(1))); + // Should not be simplified by this rule (left unchanged or handled elsewhere) + let result = simplify(expr.clone()); + // The expression should not have been simplified + assert_eq!(result, expr); + } + #[test] fn test_simplify_multiply_by_one() { let expr_a = col("c2") * lit(1); @@ -4995,4 +5187,156 @@ mod tests { else_expr: None, }) } + + // -------------------------------- + // --- Struct Cast Tests ----- + // -------------------------------- + + /// Helper to create a `Struct` literal cast expression from `source_fields` and `target_fields`. 
+ fn make_struct_cast_expr(source_fields: Fields, target_fields: Fields) -> Expr { + // Create 1-row struct array (not 0-row) so it can be evaluated by simplifier + let arrays: Vec> = vec![ + Arc::new(Int32Array::from(vec![Some(1)])), + Arc::new(Int32Array::from(vec![Some(2)])), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )) + } + + #[test] + fn test_struct_cast_different_field_counts_not_foldable() { + // Test that struct casts with different field counts are NOT marked as foldable + // When field counts differ, const-folding should not be attempted + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + Arc::new(Field::new("z", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged since field counts differ + let result = simplifier.simplify(expr.clone()).unwrap(); + // Ensure const-folding was not attempted (the expression remains exactly the same) + assert_eq!( + result, expr, + "Struct cast with different field counts should remain unchanged (no const-folding)" + ); + } + + #[test] + fn test_struct_cast_same_field_count_foldable() { + // Test that struct casts with same field counts can be considered for const-folding + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", 
DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should be simplified + let result = simplifier.simplify(expr.clone()).unwrap(); + // Struct casts with same field count should be const-folded to a literal + assert!(matches!(result, Expr::Literal(_, _))); + // Ensure the simplifier made a change (not identical to original) + assert_ne!( + result, expr, + "Struct cast with same field count should be simplified (not identical to input)" + ); + } + + #[test] + fn test_struct_cast_different_names_same_count() { + // Test struct cast with same field count but different names + // Field count matches; simplification should be skipped because names do not overlap + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged because there is no name overlap + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with different names but same field count should not be simplified" + ); + } + + #[test] + fn test_struct_cast_empty_array_not_foldable() { + // Test that struct casts with 0-row (empty) struct arrays are NOT const-folded + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", 
DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + // Create a 0-row (empty) struct array + let arrays: Vec> = vec![ + Arc::new(Int32Array::new(vec![].into(), None)), + Arc::new(Int32Array::new(vec![].into(), None)), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + let expr = Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )); + + let simplifier = + ExprSimplifier::new(SimplifyContext::default().with_schema(test_schema())); + + // The cast should remain unchanged since the struct array is empty (0-row) + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with empty (0-row) array should remain unchanged" + ); + } } diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 3ab76119cca84..b85b000821ad8 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -24,6 +24,7 @@ mod regex; pub mod simplify_exprs; pub mod simplify_literal; mod simplify_predicates; +mod udf_preimage; mod unwrap_cast; mod utils; diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs new file mode 100644 index 0000000000000..da2716d13cb47 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs @@ -0,0 +1,404 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{Result, internal_err, tree_node::Transformed}; +use datafusion_expr::{Expr, Operator, and, lit, or}; +use datafusion_expr_common::interval_arithmetic::Interval; + +/// Rewrites a binary expression using its "preimage" +/// +/// Specifically it rewrites expressions of the form ` OP x` (e.g. ` = +/// x`) where `` is known to have a pre-image (aka the entire single +/// range for which it is valid) and `x` is not `NULL` +/// +/// For details see [`datafusion_expr::ScalarUDFImpl::preimage`] +pub(super) fn rewrite_with_preimage( + preimage_interval: Interval, + op: Operator, + expr: Expr, +) -> Result> { + let (lower, upper) = preimage_interval.into_bounds(); + let (lower, upper) = (lit(lower), lit(upper)); + + let rewritten_expr = match op { + // < x ==> < lower + Operator::Lt => expr.lt(lower), + // >= x ==> >= lower + Operator::GtEq => expr.gt_eq(lower), + // > x ==> >= upper + Operator::Gt => expr.gt_eq(upper), + // <= x ==> < upper + Operator::LtEq => expr.lt(upper), + // = x ==> ( >= lower) and ( < upper) + Operator::Eq => and(expr.clone().gt_eq(lower), expr.lt(upper)), + // != x ==> ( < lower) or ( >= upper) + Operator::NotEq => or(expr.clone().lt(lower), expr.gt_eq(upper)), + // is not distinct from x ==> ( is NULL and x is NULL) or (( >= lower) and ( < upper)) + // but since x is always not NULL => ( is not NULL) and ( >= lower) and ( < upper) + 
Operator::IsNotDistinctFrom => expr + .clone() + .is_not_null() + .and(expr.clone().gt_eq(lower)) + .and(expr.lt(upper)), + // is distinct from x ==> ( < lower) or ( >= upper) or ( is NULL and x is not NULL) or ( is not NULL and x is NULL) + // but given that x is always not NULL => ( < lower) or ( >= upper) or ( is NULL) + Operator::IsDistinctFrom => expr + .clone() + .lt(lower) + .or(expr.clone().gt_eq(upper)) + .or(expr.is_null()), + _ => return internal_err!("Expect comparison operators"), + }; + Ok(Transformed::yes(rewritten_expr)) +} + +#[cfg(test)] +mod test { + use std::any::Any; + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; + use datafusion_expr::{ + ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult, + simplify::SimplifyContext, + }; + + use super::Interval; + use crate::simplify_expressions::ExprSimplifier; + + fn is_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsDistinctFrom, right) + } + + fn is_not_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsNotDistinctFrom, right) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct PreimageUdf { + /// Defaults to an exact signature with one Int32 argument and Immutable volatility + signature: Signature, + /// If true, returns a preimage; otherwise, returns None + enabled: bool, + } + + impl PreimageUdf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + enabled: true, + } + } + + /// Set the enabled flag + fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + /// Set the volatility + fn with_volatility(mut self, volatility: Volatility) -> Self { + self.signature.volatility = volatility; + self + } + } + + impl ScalarUDFImpl for PreimageUdf { + fn 
as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "preimage_func" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500)))) + } + + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + if !self.enabled { + return Ok(PreimageResult::None); + } + if args.len() != 1 { + return Ok(PreimageResult::None); + } + + let expr = args.first().cloned().expect("Should be column expression"); + match lit_expr { + Expr::Literal(ScalarValue::Int32(Some(500)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(200)), + )?), + }) + } + Expr::Literal(ScalarValue::Int32(Some(600)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(300)), + ScalarValue::Int32(Some(400)), + )?), + }) + } + _ => Ok(PreimageResult::None), + } + } + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + let simplify_context = SimplifyContext::default().with_schema(Arc::clone(schema)); + ExprSimplifier::new(simplify_context) + .simplify(expr) + .unwrap() + } + + fn preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")]) + } + + fn non_immutable_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile)) + .call(vec![col("x")]) + } + + fn no_preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false)) + .call(vec![col("x")]) + } + + fn test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![Field::new("x", DataType::Int32, true)].into(), + Default::default(), + ) + .unwrap(), + ) + } + + fn 
test_schema_xy() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Int32, false), + ] + .into(), + Default::default(), + ) + .unwrap(), + ) + } + + #[test] + fn test_preimage_eq_rewrite() { + // Equality rewrite when preimage and column expression are available. + let schema = test_schema(); + let expr = preimage_udf_expr().eq(lit(500)); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_noteq_rewrite() { + // Inequality rewrite expands to disjoint ranges. + let schema = test_schema(); + let expr = preimage_udf_expr().not_eq(lit(500)); + let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_eq_rewrite_swapped() { + // Equality rewrite works when the literal appears on the left. + let schema = test_schema(); + let expr = lit(500).eq(preimage_udf_expr()); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lt_rewrite() { + // Less-than comparison rewrites to the lower bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt(lit(500)); + let expected = col("x").lt(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lteq_rewrite() { + // Less-than-or-equal comparison rewrites to the upper bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt_eq(lit(500)); + let expected = col("x").lt(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gt_rewrite() { + // Greater-than comparison rewrites to the upper bound (inclusive). 
+ let schema = test_schema(); + let expr = preimage_udf_expr().gt(lit(500)); + let expected = col("x").gt_eq(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gteq_rewrite() { + // Greater-than-or-equal comparison rewrites to the lower bound. + let schema = test_schema(); + let expr = preimage_udf_expr().gt_eq(lit(500)); + let expected = col("x").gt_eq(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_not_distinct_from_rewrite() { + // IS NOT DISTINCT FROM rewrites to equality plus expression not-null check + // for non-null literal RHS. + let schema = test_schema(); + let expr = is_not_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .is_not_null() + .and(col("x").gt_eq(lit(100))) + .and(col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_distinct_from_rewrite() { + // IS DISTINCT FROM adds an explicit NULL branch for the column. 
+ let schema = test_schema(); + let expr = is_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .lt(lit(100)) + .or(col("x").gt_eq(lit(200))) + .or(col("x").is_null()); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false); + let expected = or( + and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))), + and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_not_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true); + let expected = and( + or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))), + or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_long_list_no_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false); + + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_preimage_non_literal_rhs_no_rewrite() { + // Non-literal RHS should not be rewritten. + let schema = test_schema_xy(); + let expr = preimage_udf_expr().eq(col("y")); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_null_literal_no_rewrite_distinct_ops() { + // NULL literal RHS should not be rewritten for DISTINCTness operators: + // - `expr IS DISTINCT FROM NULL` <=> `NOT (expr IS NULL)` + // - `expr IS NOT DISTINCT FROM NULL` <=> `expr IS NULL` + // + // For normal comparisons (=, !=, <, <=, >, >=), `expr OP NULL` evaluates to NULL + // under SQL tri-state logic, and DataFusion's simplifier constant-folds it. 
+ // https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html#boolean-tri-state-logic + + let schema = test_schema(); + + let expr = is_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + + let expr = + is_not_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_preimage_non_immutable_no_rewrite() { + // Non-immutable UDFs should not participate in preimage rewrites. + let schema = test_schema(); + let expr = non_immutable_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_no_preimage_no_rewrite() { + // If the UDF provides no preimage, the expression should remain unchanged. + let schema = test_schema(); + let expr = no_preimage_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 1f214e3d365c9..b0908b47602f7 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -290,6 +290,54 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +/// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. +/// This pattern can be simplified to just `A = L1` since if A equals L1 +/// and L1 is different from L2, then A is automatically not equal to L2. 
+pub fn is_eq_and_ne_with_different_literal(eq_expr: &Expr, ne_expr: &Expr) -> bool { + fn extract_var_and_literal(expr: &Expr) -> Option<(&Expr, &Expr)> { + match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) + | Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::NotEq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Literal(_, _), var) => Some((var, left)), + (var, Expr::Literal(_, _)) => Some((var, right)), + _ => None, + }, + _ => None, + } + } + match (eq_expr, ne_expr) { + ( + Expr::BinaryExpr(BinaryExpr { + op: Operator::Eq, .. + }), + Expr::BinaryExpr(BinaryExpr { + op: Operator::NotEq, + .. + }), + ) => { + // Check if both compare the same expression against different literals + if let (Some((var1, lit1)), Some((var2, lit2))) = ( + extract_var_and_literal(eq_expr), + extract_var_and_literal(ne_expr), + ) && var1 == var2 + && lit1 != lit2 + { + return true; + } + false + } + _ => false, + } +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 05edd230daccb..00c8fab228117 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -184,7 +184,11 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, params: AggregateFunctionParams { - mut args, distinct, .. 
+ mut args, + distinct, + filter, + order_by, + null_treatment, }, }) => { if distinct { @@ -204,9 +208,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here - None, - vec![], - None, + filter, + order_by, + null_treatment, ))) // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { @@ -217,9 +221,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { Arc::clone(&func), args, false, - None, - vec![], - None, + filter, + order_by, + null_treatment, )) .alias(&alias_str), ); diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index a45983950496d..2915e77be2e12 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -24,6 +24,7 @@ use datafusion_common::{Result, assert_contains}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan}; use std::sync::Arc; +pub mod udfs; pub mod user_defined; pub fn test_table_scan_fields() -> Vec { @@ -34,6 +35,28 @@ pub fn test_table_scan_fields() -> Vec { ] } +pub fn test_table_scan_with_struct_fields() -> Vec { + vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "user", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), + Field::new("status", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ] +} + +pub fn test_table_scan_with_struct() -> Result { + let schema = Schema::new(test_table_scan_with_struct_fields()); + table_scan(Some("test"), &schema, None)?.build() +} + /// some tests share a common table with different names pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(test_table_scan_fields()); diff --git a/datafusion/optimizer/src/test/udfs.rs b/datafusion/optimizer/src/test/udfs.rs new file mode 100644 index 0000000000000..9164603dba3d5 --- /dev/null +++ b/datafusion/optimizer/src/test/udfs.rs @@ -0,0 +1,103 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::{ + ColumnarValue, Expr, ExpressionPlacement, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, +}; + +/// A configurable test UDF for optimizer tests. +/// Defaults to `MoveTowardsLeafNodes` placement. Use `with_placement()` to override. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct PlacementTestUDF { + signature: Signature, + placement: ExpressionPlacement, + id: usize, +} + +impl Default for PlacementTestUDF { + fn default() -> Self { + Self::new() + } +} + +impl PlacementTestUDF { + pub fn new() -> Self { + Self { + // Accept any one or two arguments and return UInt32 for testing purposes. + // The actual types don't matter since this UDF is not intended for execution. 
+ signature: Signature::new( + TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]), + datafusion_expr::Volatility::Immutable, + ), + placement: ExpressionPlacement::MoveTowardsLeafNodes, + id: 0, + } + } + + /// Set the expression placement for this UDF, which is used by optimizer rules to determine where in the plan the expression should be placed. + /// This also resets the name of the UDF to a default based on the placement. + pub fn with_placement(mut self, placement: ExpressionPlacement) -> Self { + self.placement = placement; + self + } + + /// Set the id of the UDF. + /// This is an arbitrary made up field to allow creating multiple distinct UDFs with the same placement. + pub fn with_id(mut self, id: usize) -> Self { + self.id = id; + self + } +} + +impl ScalarUDFImpl for PlacementTestUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + match self.placement { + ExpressionPlacement::MoveTowardsLeafNodes => "leaf_udf", + ExpressionPlacement::KeepInPlace => "keep_in_place_udf", + ExpressionPlacement::Column => "column_udf", + ExpressionPlacement::Literal => "literal_udf", + } + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::UInt32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("PlacementTestUDF: not intended for execution") + } + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.placement + } +} + +/// Create a `leaf_udf(arg)` expression with `MoveTowardsLeafNodes` placement. 
+pub fn leaf_udf_expr(arg: Expr) -> Expr { + let udf = ScalarUDF::new_from_impl( + PlacementTestUDF::new().with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + ); + udf.call(vec![arg]) +} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 36a6df54ddaf0..fd4991c24413f 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -543,7 +543,7 @@ fn recursive_cte_projection_pushdown() -> Result<()> { RecursiveQuery: is_distinct=false Projection: test.col_int32 AS id TableScan: test projection=[col_int32] - Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) AS id + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) Filter: nodes.id < Int32(3) TableScan: nodes projection=[id] " @@ -567,7 +567,7 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> { RecursiveQuery: is_distinct=false Projection: test.col_int32 AS id TableScan: test projection=[col_int32] - Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) AS id + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) SubqueryAlias: child Filter: nodes.id < Int32(3) TableScan: nodes projection=[id] @@ -630,7 +630,7 @@ fn recursive_cte_projection_pushdown_baseline() -> Result<()> { Projection: test.col_int32 AS n Filter: test.col_int32 = Int32(5) TableScan: test projection=[col_int32] - Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) AS n + Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) Filter: countdown.n > Int32(1) TableScan: countdown projection=[n] " diff --git a/datafusion/physical-expr-adapter/LICENSE.txt b/datafusion/physical-expr-adapter/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/physical-expr-adapter/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-adapter/NOTICE.txt 
b/datafusion/physical-expr-adapter/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/physical-expr-adapter/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/physical-expr-adapter/src/lib.rs b/datafusion/physical-expr-adapter/src/lib.rs index d7c750e4a1a1c..ea4db19ee110e 100644 --- a/datafusion/physical-expr-adapter/src/lib.rs +++ b/datafusion/physical-expr-adapter/src/lib.rs @@ -21,14 +21,13 @@ html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] #![cfg_attr(docsrs, feature(doc_cfg))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Physical expression schema adaptation utilities for DataFusion pub mod schema_rewriter; pub use schema_rewriter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, - PhysicalExprAdapterFactory, replace_columns_with_literals, + BatchAdapter, BatchAdapterFactory, DefaultPhysicalExprAdapter, + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, }; diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 83727ac092044..5a9ee8502eaa9 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -24,20 +24,24 @@ use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use arrow::array::RecordBatch; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ Result, ScalarValue, exec_err, nested_struct::validate_struct_compatibility, tree_node::{Transformed, TransformedResult, TreeNode}, }; use 
datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::PhysicalExprSimplifier; use datafusion_physical_expr::expressions::CastColumnExpr; +use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::{ ScalarFunctionExpr, expressions::{self, Column}, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; /// Replace column references in the given physical expression with literal values. /// @@ -137,11 +141,11 @@ where /// &self, /// logical_file_schema: SchemaRef, /// physical_file_schema: SchemaRef, -/// ) -> Arc { -/// Arc::new(CustomPhysicalExprAdapter { +/// ) -> Result> { +/// Ok(Arc::new(CustomPhysicalExprAdapter { /// logical_file_schema, /// physical_file_schema, -/// }) +/// })) /// } /// } /// ``` @@ -174,7 +178,7 @@ pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc; + ) -> Result>; } #[derive(Debug, Clone)] @@ -185,11 +189,11 @@ impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { &self, logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(DefaultPhysicalExprAdapter { + ) -> Result> { + Ok(Arc::new(DefaultPhysicalExprAdapter { logical_file_schema, physical_file_schema, - }) + })) } } @@ -228,7 +232,8 @@ impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { /// # logical_file_schema: &Schema, /// # ) -> datafusion_common::Result<()> { /// let factory = DefaultPhysicalExprAdapterFactory; -/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone())); +/// let adapter = +/// factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone()))?; /// let adapted_predicate = adapter.rewrite(predicate)?; /// # Ok(()) /// # } @@ -255,20 +260,20 @@ impl DefaultPhysicalExprAdapter { impl PhysicalExprAdapter 
for DefaultPhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &self.logical_file_schema, - physical_file_schema: &self.physical_file_schema, + logical_file_schema: Arc::clone(&self.logical_file_schema), + physical_file_schema: Arc::clone(&self.physical_file_schema), }; expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } } -struct DefaultPhysicalExprAdapterRewriter<'a> { - logical_file_schema: &'a Schema, - physical_file_schema: &'a Schema, +struct DefaultPhysicalExprAdapterRewriter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, } -impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { +impl DefaultPhysicalExprAdapterRewriter { fn rewrite_expr( &self, expr: Arc, @@ -416,18 +421,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let physical_field = self.physical_file_schema.field(physical_column_index); - let column = match ( - column.index() == physical_column_index, - logical_field.data_type() == physical_field.data_type(), - ) { - // If the column index matches and the data types match, we can use the column as is - (true, true) => return Ok(Transformed::no(expr)), - // If the indexes or data types do not match, we need to create a new column expression - (true, _) => column.clone(), - (false, _) => { - Column::new_with_schema(logical_field.name(), self.physical_file_schema)? 
- } - }; + if column.index() == physical_column_index + && logical_field.data_type() == physical_field.data_type() + { + return Ok(Transformed::no(expr)); + } + + let column = self.resolve_column(column, physical_column_index)?; if logical_field.data_type() == physical_field.data_type() { // If the data types match, we can use the column as is @@ -438,24 +438,60 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` // since that's much cheaper to evalaute. // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 - // + self.create_cast_column_expr(column, logical_field) + } + + /// Resolves a column expression, handling index and type mismatches. + /// + /// Returns the appropriate Column expression when the column's index or data type + /// don't match the physical schema. Assumes that the early-exit case (both index + /// and type match) has already been checked by the caller. + fn resolve_column( + &self, + column: &Column, + physical_column_index: usize, + ) -> Result { + if column.index() == physical_column_index { + Ok(column.clone()) + } else { + Column::new_with_schema(column.name(), self.physical_file_schema.as_ref()) + } + } + + /// Validates type compatibility and creates a CastColumnExpr if needed. + /// + /// Checks whether the physical field can be cast to the logical field type, + /// handling both struct and scalar types. Returns a CastColumnExpr with the + /// appropriate configuration. 
+ fn create_cast_column_expr( + &self, + column: Column, + logical_field: &Field, + ) -> Result>> { + let actual_physical_field = self.physical_file_schema.field(column.index()); + // For struct types, use validate_struct_compatibility which handles: // - Missing fields in source (filled with nulls) // - Extra fields in source (ignored) // - Recursive validation of nested structs // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { + match (actual_physical_field.data_type(), logical_field.data_type()) { (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility(physical_fields, logical_fields)?; + validate_struct_compatibility( + physical_fields.as_ref(), + logical_fields.as_ref(), + )?; } _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); + let is_compatible = can_cast_types( + actual_physical_field.data_type(), + logical_field.data_type(), + ); if !is_compatible { return exec_err!( "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", column.name(), - physical_field.data_type(), + actual_physical_field.data_type(), logical_field.data_type() ); } @@ -464,7 +500,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { let cast_expr = Arc::new(CastColumnExpr::new( Arc::new(column), - Arc::new(physical_field.clone()), + Arc::new(actual_physical_field.clone()), Arc::new(logical_field.clone()), None, )); @@ -473,6 +509,141 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { } } +/// Factory for creating [`BatchAdapter`] instances to adapt record batches +/// to a target schema. +/// +/// This binds a target schema and allows creating adapters for different source schemas. 
+/// It handles: +/// - **Column reordering**: Columns are reordered to match the target schema +/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64) +/// - **Missing columns**: Nullable columns missing from source are filled with nulls +/// - **Struct field adaptation**: Nested struct fields are recursively adapted +/// +/// ## Examples +/// +/// ```rust +/// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch}; +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_physical_expr_adapter::BatchAdapterFactory; +/// use std::sync::Arc; +/// +/// // Target schema has different column order and types +/// let target_schema = Arc::new(Schema::new(vec![ +/// Field::new("name", DataType::Utf8, true), +/// Field::new("id", DataType::Int64, false), // Int64 in target +/// Field::new("score", DataType::Float64, true), // Missing from source +/// ])); +/// +/// // Source schema has different column order and Int32 for id +/// let source_schema = Arc::new(Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), // Int32 in source +/// Field::new("name", DataType::Utf8, true), +/// // Note: 'score' column is missing from source +/// ])); +/// +/// // Create factory with target schema +/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); +/// +/// // Create adapter for this specific source schema +/// let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); +/// +/// // Create a source batch +/// let source_batch = RecordBatch::try_new( +/// source_schema, +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])), +/// ], +/// ).unwrap(); +/// +/// // Adapt the batch to match target schema +/// let adapted = adapter.adapt_batch(&source_batch).unwrap(); +/// +/// assert_eq!(adapted.num_columns(), 3); +/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name +/// assert_eq!(adapted.column(1).data_type(), 
&DataType::Int64); // id (cast from Int32) +/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls) +/// ``` +#[derive(Debug)] +pub struct BatchAdapterFactory { + target_schema: SchemaRef, + expr_adapter_factory: Arc, +} + +impl BatchAdapterFactory { + /// Create a new [`BatchAdapterFactory`] with the given target schema. + pub fn new(target_schema: SchemaRef) -> Self { + let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory); + Self { + target_schema, + expr_adapter_factory, + } + } + + /// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions. + /// + /// Use this to customize behavior when adapting batches, e.g. to fill in missing values + /// with defaults instead of nulls. + /// + /// See [`PhysicalExprAdapter`] for more details. + pub fn with_adapter_factory( + self, + factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: factory, + ..self + } + } + + /// Create a new [`BatchAdapter`] for the given source schema. + /// + /// Batches fed into this [`BatchAdapter`] *must* conform to the source schema, + /// no validation is performed at runtime to minimize overheads. + pub fn make_adapter(&self, source_schema: SchemaRef) -> Result { + let expr_adapter = self + .expr_adapter_factory + .create(Arc::clone(&self.target_schema), Arc::clone(&source_schema))?; + + let simplifier = PhysicalExprSimplifier::new(&self.target_schema); + + let projection = ProjectionExprs::from_indices( + &(0..self.target_schema.fields().len()).collect_vec(), + &self.target_schema, + ); + + let adapted = projection + .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?; + let projector = adapted.make_projector(&source_schema)?; + + Ok(BatchAdapter { projector }) + } +} + +/// Adapter for transforming record batches to match a target schema. +/// +/// Create instances via [`BatchAdapterFactory`]. 
+/// +/// ## Performance +/// +/// The adapter pre-computes the projection expressions during creation, +/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable +/// for use in hot paths like streaming file scans. +#[derive(Debug)] +pub struct BatchAdapter { + projector: Projector, +} + +impl BatchAdapter { + /// Adapt the given record batch to match the target schema. + /// + /// The input batch *must* conform to the source schema used when + /// creating this adapter. + pub fn adapt_batch(&self, batch: &RecordBatch) -> Result { + self.projector.project_batch(batch) + } +} + #[cfg(test)] mod tests { use super::*; @@ -508,7 +679,9 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("a", 0)); let result = adapter.rewrite(column_expr).unwrap(); @@ -521,7 +694,9 @@ mod tests { fn test_rewrite_multi_column_expr_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter let column_a = Arc::new(Column::new("a", 0)) as Arc; @@ -586,7 +761,9 @@ mod tests { )]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("data", 0)); let error_msg = 
adapter.rewrite(column_expr).unwrap_err().to_string(); @@ -624,35 +801,39 @@ mod tests { )]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("data", 0)); let result = adapter.rewrite(column_expr).unwrap(); + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let physical_field = Arc::new(Field::new( + "data", + DataType::Struct(physical_struct_fields), + false, + )); + + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(); + let logical_field = Arc::new(Field::new( + "data", + DataType::Struct(logical_struct_fields), + false, + )); + let expected = Arc::new(CastColumnExpr::new( Arc::new(Column::new("data", 0)), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ] - .into(), - ), - false, - )), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int64, false), - Field::new("name", DataType::Utf8View, true), - ] - .into(), - ), - false, - )), + physical_field, + logical_field, None, )) as Arc; @@ -664,7 +845,9 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("c", 2)); let result = adapter.rewrite(column_expr)?; @@ -688,7 +871,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = 
factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string(); @@ -704,7 +889,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let result = adapter.rewrite(column_expr).unwrap(); @@ -770,7 +957,9 @@ mod tests { let (physical_schema, logical_schema) = create_test_schema(); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)) as Arc; let result = adapter.rewrite(Arc::clone(&column_expr))?; @@ -794,7 +983,9 @@ mod tests { ]); let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); let column_expr = Arc::new(Column::new("b", 1)); let result = adapter.rewrite(column_expr); @@ -852,8 +1043,9 @@ mod tests { ]; let factory = DefaultPhysicalExprAdapterFactory; - let adapter = - factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); let adapted_projection = projection .into_iter() @@ -954,8 +1146,9 @@ mod tests { let projection = vec![col("data", &logical_schema).unwrap()]; let factory = DefaultPhysicalExprAdapterFactory; - let adapter = - 
factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); let adapted_projection = projection .into_iter() @@ -1033,8 +1226,8 @@ mod tests { )]); let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &logical_schema, - physical_file_schema: &physical_schema, + logical_file_schema: Arc::new(logical_schema), + physical_file_schema: Arc::new(physical_schema), }; // Test that when a field exists in physical schema, it returns None @@ -1046,4 +1239,257 @@ mod tests { // with ScalarUDF, which is complex to set up in a unit test. The integration tests in // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality. } + + // ============================================================================ + // BatchAdapterFactory and BatchAdapter tests + // ============================================================================ + + #[test] + fn test_batch_adapter_factory_basic() { + // Target schema + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), + ])); + + // Source schema with different column order and type + let source_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Int32 -> Int64 + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); + + // Create source batch + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + // Verify schema matches target + assert_eq!(adapted.num_columns(), 2); + 
assert_eq!(adapted.schema().field(0).name(), "a"); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64); + assert_eq!(adapted.schema().field(1).name(), "b"); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + + // Verify data + let col_a = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]); + + let col_b = adapted + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + col_b.iter().collect_vec(), + vec![Some("hello"), None, Some("world")] + ); + } + + #[test] + fn test_batch_adapter_factory_missing_column() { + // Target schema with a column missing from source + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), // exists in source + Field::new("c", DataType::Float64, true), // missing from source + ])); + + let source_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + assert_eq!(adapted.num_columns(), 3); + + // Missing column should be filled with nulls + let col_c = adapted.column(2); + assert_eq!(col_c.data_type(), &DataType::Float64); + assert_eq!(col_c.null_count(), 2); // All nulls + } + + #[test] + fn test_batch_adapter_factory_with_struct() { + // Target has struct with Int64 id + let target_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let target_schema = 
Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(target_struct_fields), + false, + )])); + + // Source has struct with Int32 id + let source_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let source_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(source_struct_fields.clone()), + false, + )])); + + let struct_array = StructArray::new( + source_struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20])) as _, + Arc::new(StringArray::from(vec!["a", "b"])) as _, + ], + None, + ); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![Arc::new(struct_array)], + ) + .unwrap(); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(source_schema).unwrap(); + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + let result_struct = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify id was cast to Int64 + let id_col = result_struct.column_by_name("id").unwrap(); + assert_eq!(id_col.data_type(), &DataType::Int64); + let id_values = id_col.as_any().downcast_ref::().unwrap(); + assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]); + } + + #[test] + fn test_batch_adapter_factory_identity() { + // When source and target schemas are identical, should pass through efficiently + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&schema)); + let adapter = factory.make_adapter(Arc::clone(&schema)).unwrap(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&batch).unwrap(); + + 
assert_eq!(adapted.num_columns(), 2); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + } + + #[test] + fn test_batch_adapter_factory_reuse() { + // Factory can create multiple adapters for different source schemas + let target_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + + // First source schema + let source1 = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ])); + let adapter1 = factory.make_adapter(source1).unwrap(); + + // Second source schema (different order) + let source2 = Arc::new(Schema::new(vec![ + Field::new("y", DataType::Utf8, true), + Field::new("x", DataType::Int64, false), + ])); + let adapter2 = factory.make_adapter(source2).unwrap(); + + // Both should work correctly + assert!(format!("{:?}", adapter1).contains("BatchAdapter")); + assert!(format!("{:?}", adapter2).contains("BatchAdapter")); + } + + #[test] + fn test_rewrite_column_index_and_type_mismatch() { + let physical_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Index 1 + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Index 0, Different Type + Field::new("b", DataType::Utf8, true), + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory + .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .unwrap(); + + // Logical column "a" is at index 0 + let column_expr = Arc::new(Column::new("a", 0)); + + let result = adapter.rewrite(column_expr).unwrap(); + + // Should be a CastColumnExpr + let cast_expr = result + .as_any() + .downcast_ref::() + .expect("Expected CastColumnExpr"); + + // Verify the inner column points to the correct physical 
index (1) + let inner_col = cast_expr + .expr() + .as_any() + .downcast_ref::() + .expect("Expected inner Column"); + assert_eq!(inner_col.name(), "a"); + assert_eq!(inner_col.index(), 1); // Physical index is 1 + + // Verify cast types + assert_eq!( + cast_expr.data_type(&Schema::empty()).unwrap(), + DataType::Int64 + ); + } } diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index ab95302bbb046..95d085ddfdb6c 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -389,7 +389,7 @@ where // is value is already present in the set? let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash || header.len != value_len { return false; } // value is stored inline so no need to consult buffer @@ -427,7 +427,7 @@ where // Check if the value is already present in the set let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash { return false; } // Need to compare the bytes in the buffer diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 2de563472c789..ff0b7c71eec82 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -17,16 +17,17 @@ //! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from //! `StringViewArray`/`BinaryViewArray`. -//! Much of the code is from `binary_map.rs`, but with simpler implementation because we directly use the -//! [`GenericByteViewBuilder`]. 
use crate::binary_map::OutputType; use ahash::RandomState; +use arrow::array::NullBufferBuilder; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; +use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view}; +use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; +use std::mem::size_of; use std::sync::Arc; /// HashSet optimized for storing string or binary values that can produce that @@ -113,6 +114,9 @@ impl ArrowBytesViewSet { /// This map is used by the special `COUNT DISTINCT` aggregate function to /// store the distinct values, and by the `GROUP BY` operator to store /// group values when they are a single string array. +/// Max size of the in-progress buffer before flushing to completed buffers +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + pub struct ArrowBytesViewMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -124,8 +128,15 @@ where /// Total size of the map in bytes map_size: usize, - /// Builder for output array - builder: GenericByteViewBuilder, + /// Views for all stored values (in insertion order) + views: Vec, + /// In-progress buffer for out-of-line string data + in_progress: Vec, + /// Completed buffers containing string data + completed: Vec, + /// Tracks null values (true = null) + nulls: NullBufferBuilder, + /// random state used to generate hashes random_state: RandomState, /// buffer that stores hash values (reused across batches to save allocations) @@ -148,7 +159,10 @@ where output_type, map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, - builder: GenericByteViewBuilder::new(), + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + nulls: NullBufferBuilder::new(0), 
random_state: RandomState::new(), hashes_buffer: vec![], null: None, @@ -250,53 +264,92 @@ where // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); + // Get raw views buffer for direct comparison + let input_views = values.views(); + // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); + assert_eq!(values.len(), self.hashes_buffer.len()); + + for i in 0..values.len() { + let view_u128 = input_views[i]; + let hash = self.hashes_buffer[i]; - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // handle null value - let Some(value) = value else { + // handle null value via validity bitmap check + if !values.is_valid(i) { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { let payload = make_payload_fn(None); - let null_index = self.builder.len(); - self.builder.append_null(); + let null_index = self.views.len(); + self.views.push(0); + self.nulls.append_null(); self.null = Some((payload, null_index)); payload }; observe_payload_fn(payload); continue; - }; - - // get the value as bytes - let value: &[u8] = value.as_ref(); + } - let entry = self.map.find_mut(hash, |header| { - let v = self.builder.get_value(header.view_idx); + // Extract length from the view (first 4 bytes of u128 in little-endian) + let len = view_u128 as u32; - if v.len() != value.len() { - return false; - } + // Check if value already exists + let maybe_payload = { + // Borrow completed and in_progress for comparison + let completed = &self.completed; + let in_progress = &self.in_progress; - v == value - }); + self.map + .find(hash, |header| { + if header.hash != hash { + return false; + } + + // Fast path: inline strings can be compared directly + if len <= 12 { + return header.view == view_u128; + } + + // For larger strings: first compare the 4-byte prefix + let stored_prefix = (header.view >> 32) as u32; + let input_prefix = (view_u128 >> 32) as u32; + if stored_prefix 
!= input_prefix { + return false; + } + + // Prefix matched - compare full bytes + let byte_view = ByteView::from(header.view); + let stored_len = byte_view.length as usize; + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + + let stored_value = if buffer_index < completed.len() { + &completed[buffer_index].as_slice() + [offset..offset + stored_len] + } else { + &in_progress[offset..offset + stored_len] + }; + let input_value: &[u8] = values.value(i).as_ref(); + stored_value == input_value + }) + .map(|entry| entry.payload) + }; - let payload = if let Some(entry) = entry { - entry.payload + let payload = if let Some(payload) = maybe_payload { + payload } else { - // no existing value, make a new one. + // no existing value, make a new one + let value: &[u8] = values.value(i).as_ref(); let payload = make_payload_fn(Some(value)); - let inner_view_idx = self.builder.len(); + // Create view pointing to our buffers + let new_view = self.append_value(value); let new_header = Entry { - view_idx: inner_view_idx, + view: new_view, hash, payload, }; - self.builder.append_value(value); - self.map .insert_accounted(new_header, |h| h.hash, &mut self.map_size); payload @@ -311,29 +364,58 @@ where /// /// The values are guaranteed to be returned in the same order in which /// they were first seen. 
- pub fn into_state(self) -> ArrayRef { - let mut builder = self.builder; - match self.output_type { - OutputType::BinaryView => { - let array = builder.finish(); + pub fn into_state(mut self) -> ArrayRef { + // Flush any remaining in-progress buffer + if !self.in_progress.is_empty() { + let flushed = std::mem::take(&mut self.in_progress); + self.completed.push(Buffer::from_vec(flushed)); + } - Arc::new(array) - } + // Build null buffer if we have any nulls + let null_buffer = self.nulls.finish(); + + let views = ScalarBuffer::from(self.views); + let array = + unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) }; + + match self.output_type { + OutputType::BinaryView => Arc::new(array), OutputType::Utf8View => { - // SAFETY: - // we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - let array = builder.finish(); + // SAFETY: all input was valid utf8 let array = unsafe { array.to_string_view_unchecked() }; Arc::new(array) } - _ => { - unreachable!("Utf8/Binary should use `ArrowBytesMap`") - } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), } } + /// Append a value to our buffers and return the view pointing to it + fn append_value(&mut self, value: &[u8]) -> u128 { + let len = value.len(); + let view = if len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure buffer is big enough + if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { + let flushed = std::mem::replace( + &mut self.in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + self.completed.push(Buffer::from_vec(flushed)); + } + + let buffer_index = self.completed.len() as u32; + let offset = self.in_progress.len() as u32; + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index, offset) + }; + + self.views.push(view); + self.nulls.append_non_null(); + view + } + /// Total number of entries (including null, if 
present) pub fn len(&self) -> usize { self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) @@ -352,8 +434,16 @@ where /// Return the total size, in bytes, of memory used to store the data in /// this set, not including `self` pub fn size(&self) -> usize { + let views_size = self.views.len() * size_of::(); + let in_progress_size = self.in_progress.capacity(); + let completed_size: usize = self.completed.iter().map(|b| b.len()).sum(); + let nulls_size = self.nulls.allocated_size(); + self.map_size - + self.builder.allocated_size() + + views_size + + in_progress_size + + completed_size + + nulls_size + self.hashes_buffer.allocated_size() } } @@ -366,7 +456,8 @@ where f.debug_struct("ArrowBytesMap") .field("map", &"") .field("map_size", &self.map_size) - .field("view_builder", &self.builder) + .field("views_len", &self.views.len()) + .field("completed_buffers", &self.completed.len()) .field("random_state", &self.random_state) .field("hashes_buffer", &self.hashes_buffer) .finish() @@ -374,13 +465,20 @@ where } /// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details +/// +/// Stores the view pointing to our internal buffers, eliminating the need +/// for a separate builder index. For inline strings (<=12 bytes), the view +/// contains the entire value. For out-of-line strings, the view contains +/// buffer_index and offset pointing directly to our storage. #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] struct Entry where V: Debug + PartialEq + Eq + Clone + Copy + Default, { - /// The idx into the views array - view_idx: usize, + /// The u128 view pointing to our internal buffers. For inline strings, + /// this contains the complete value. For larger strings, this contains + /// the buffer_index/offset into our completed/in_progress buffers. 
+ view: u128, hash: u64, diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 84378a3d26eee..b6eaacdca2505 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Physical Expr Common packages for [DataFusion] //! This package contains high level PhysicalExpr trait diff --git a/datafusion/physical-expr-common/src/metrics/value.rs b/datafusion/physical-expr-common/src/metrics/value.rs index 9a14b804a20b5..26f68980bad8e 100644 --- a/datafusion/physical-expr-common/src/metrics/value.rs +++ b/datafusion/physical-expr-common/src/metrics/value.rs @@ -372,19 +372,31 @@ impl Drop for ScopedTimerGuard<'_> { pub struct PruningMetrics { pruned: Arc<AtomicUsize>, matched: Arc<AtomicUsize>, + fully_matched: Arc<AtomicUsize>, } impl Display for PruningMetrics { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let matched = self.matched.load(Ordering::Relaxed); let total = self.pruned.load(Ordering::Relaxed) + matched; + let fully_matched = self.fully_matched.load(Ordering::Relaxed); - write!( - f, - "{} total → {} matched", - human_readable_count(total), - human_readable_count(matched) - ) + if fully_matched != 0 { + write!( + f, + "{} total → {} matched -> {} fully matched", + human_readable_count(total), + human_readable_count(matched), + human_readable_count(fully_matched) + ) + } else { + write!( + f, + "{} total → {} matched", + human_readable_count(total), + human_readable_count(matched) + ) + } } } @@ -400,6 +412,7 @@ impl PruningMetrics { Self { pruned: Arc::new(AtomicUsize::new(0)), matched: Arc::new(AtomicUsize::new(0)), + fully_matched: Arc::new(AtomicUsize::new(0)), } } @@ -417,6 +430,13 @@ impl PruningMetrics {
self.matched.fetch_add(n, Ordering::Relaxed); } + /// Add `n` to the metric's fully matched value + pub fn add_fully_matched(&self, n: usize) { + // relaxed ordering for operations on `value` poses no issues + // we're purely using atomic ops with no associated memory ops + self.fully_matched.fetch_add(n, Ordering::Relaxed); + } + /// Subtract `n` to the metric's matched value. pub fn subtract_matched(&self, n: usize) { // relaxed ordering for operations on `value` poses no issues @@ -433,6 +453,11 @@ impl PruningMetrics { pub fn matched(&self) -> usize { self.matched.load(Ordering::Relaxed) } + + /// Number of items fully matched + pub fn fully_matched(&self) -> usize { + self.fully_matched.load(Ordering::Relaxed) + } } /// Counters tracking ratio metrics (e.g. matched vs total) @@ -906,8 +931,11 @@ impl MetricValue { ) => { let pruned = other_pruning_metrics.pruned.load(Ordering::Relaxed); let matched = other_pruning_metrics.matched.load(Ordering::Relaxed); + let fully_matched = + other_pruning_metrics.fully_matched.load(Ordering::Relaxed); pruning_metrics.add_pruned(pruned); pruning_metrics.add_matched(matched); + pruning_metrics.add_fully_matched(fully_matched); } ( Self::Ratio { ratio_metrics, .. }, @@ -956,20 +984,21 @@ impl MetricValue { "files_ranges_pruned_statistics" => 4, "row_groups_pruned_statistics" => 5, "row_groups_pruned_bloom_filter" => 6, - "page_index_rows_pruned" => 7, - _ => 8, + "page_index_pages_pruned" => 7, + "page_index_rows_pruned" => 8, + _ => 9, }, - Self::SpillCount(_) => 9, - Self::SpilledBytes(_) => 10, - Self::SpilledRows(_) => 11, - Self::CurrentMemoryUsage(_) => 12, - Self::Count { .. } => 13, - Self::Gauge { .. } => 14, - Self::Time { .. } => 15, - Self::Ratio { .. } => 16, - Self::StartTimestamp(_) => 17, // show timestamps last - Self::EndTimestamp(_) => 18, - Self::Custom { .. 
} => 19, + Self::SpillCount(_) => 10, + Self::SpilledBytes(_) => 11, + Self::SpilledRows(_) => 12, + Self::CurrentMemoryUsage(_) => 13, + Self::Count { .. } => 14, + Self::Gauge { .. } => 15, + Self::Time { .. } => 16, + Self::Ratio { .. } => 17, + Self::StartTimestamp(_) => 18, // show timestamps last + Self::EndTimestamp(_) => 19, + Self::Custom { .. } => 20, } } diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 2358a21940912..7107b0a9004d3 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -35,6 +35,7 @@ use datafusion_common::{ }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_expr_common::sort_properties::ExprProperties; use datafusion_expr_common::statistics::Distribution; @@ -430,6 +431,16 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { fn is_volatile_node(&self) -> bool { false } + + /// Returns placement information for this expression. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + /// + /// The default implementation returns [`ExpressionPlacement::KeepInPlace`]. 
+ fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::KeepInPlace + } } #[deprecated( diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 1b23beeaa37cc..7e61be3a16aec 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -85,5 +85,9 @@ name = "is_null" harness = false name = "binary_op" +[[bench]] +harness = false +name = "simplify" + [package.metadata.cargo-machete] ignored = ["half"] diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index eb0886a31e8df..33931a2ba98e4 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -20,6 +20,7 @@ use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::seedable_rng; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, case, col, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -93,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { run_benchmarks(c, &make_batch(8192, 100)); benchmark_lookup_table_case_when(c, 8192); + benchmark_divide_by_zero_protection(c, 8192); } fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { @@ -517,5 +519,83 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) { } } +fn benchmark_divide_by_zero_protection(c: &mut Criterion, batch_size: usize) { + let mut group = c.benchmark_group("divide_by_zero_protection"); + + for zero_percentage in [0.0, 0.1, 0.5, 0.9] { + let rng = &mut seedable_rng(); + + let numerator: Int32Array = + (0..batch_size).map(|_| Some(rng.random::())).collect(); + + let divisor_values: Vec> = (0..batch_size) + .map(|_| { + let roll: f32 = rng.random(); + if roll < zero_percentage { + Some(0) + } else { + 
let mut val = rng.random::<i32>(); + while val == 0 { + val = rng.random::<i32>(); + } + Some(val) + } + }) + .collect(); + + let divisor: Int32Array = divisor_values.iter().cloned().collect(); + let divisor_copy: Int32Array = divisor_values.iter().cloned().collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("numerator", numerator.data_type().clone(), true), + Field::new("divisor", divisor.data_type().clone(), true), + Field::new("divisor_copy", divisor_copy.data_type().clone(), true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(numerator), + Arc::new(divisor), + Arc::new(divisor_copy), + ], + ) + .unwrap(); + + let numerator_col = col("numerator", &batch.schema()).unwrap(); + let divisor_col = col("divisor", &batch.schema()).unwrap(); + + // DivideByZeroProtection: WHEN condition checks `divisor_col > 0` and division + // uses `divisor_col` as divisor. Since the checked column matches the divisor, + // this triggers the DivideByZeroProtection optimization.
+ group.bench_function( + format!( + "{} rows, {}% zeros: DivideByZeroProtection", + batch_size, + (zero_percentage * 100.0) as i32 + ), + |b| { + let when = Arc::new(BinaryExpr::new( + Arc::clone(&divisor_col), + Operator::NotEq, + lit(0i32), + )); + let then = Arc::new(BinaryExpr::new( + Arc::clone(&numerator_col), + Operator::Divide, + Arc::clone(&divisor_col), + )); + let else_null: Arc = lit(ScalarValue::Int32(None)); + let expr = + Arc::new(case(None, vec![(when, then)], Some(else_null)).unwrap()); + + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }, + ); + } + + group.finish(); +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/simplify.rs b/datafusion/physical-expr/benches/simplify.rs new file mode 100644 index 0000000000000..cc00c710004e8 --- /dev/null +++ b/datafusion/physical-expr/benches/simplify.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This is an attempt at reproducing some predicates generated by TPC-DS query #76, +//! and trying to figure out how long it takes to simplify them. 
+ +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, Column, IsNullExpr, Literal, +}; + +fn catalog_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("cs_sold_date_sk", DataType::Int64, true), // 0 + Field::new("cs_sold_time_sk", DataType::Int64, true), // 1 + Field::new("cs_ship_date_sk", DataType::Int64, true), // 2 + Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3 + Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4 + Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5 + Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6 + Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7 + Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8 + Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9 + Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10 + Field::new("cs_call_center_sk", DataType::Int64, true), // 11 + Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12 + Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13 + Field::new("cs_warehouse_sk", DataType::Int64, true), // 14 + Field::new("cs_item_sk", DataType::Int64, true), // 15 + Field::new("cs_promo_sk", DataType::Int64, true), // 16 + Field::new("cs_order_number", DataType::Int64, true), // 17 + Field::new("cs_quantity", DataType::Int64, true), // 18 + Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_sales_price", 
DataType::Decimal128(7, 2), true), + Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +fn web_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("ws_sold_date_sk", DataType::Int64, true), + Field::new("ws_sold_time_sk", DataType::Int64, true), + Field::new("ws_ship_date_sk", DataType::Int64, true), + Field::new("ws_item_sk", DataType::Int64, true), + Field::new("ws_bill_customer_sk", DataType::Int64, true), + Field::new("ws_bill_cdemo_sk", DataType::Int64, true), + Field::new("ws_bill_hdemo_sk", DataType::Int64, true), + Field::new("ws_bill_addr_sk", DataType::Int64, true), + Field::new("ws_ship_customer_sk", DataType::Int64, true), + Field::new("ws_ship_cdemo_sk", DataType::Int64, true), + Field::new("ws_ship_hdemo_sk", DataType::Int64, true), + Field::new("ws_ship_addr_sk", DataType::Int64, true), + Field::new("ws_web_page_sk", DataType::Int64, true), + Field::new("ws_web_site_sk", DataType::Int64, true), + Field::new("ws_ship_mode_sk", DataType::Int64, true), + Field::new("ws_warehouse_sk", DataType::Int64, true), + Field::new("ws_promo_sk", DataType::Int64, true), + Field::new("ws_order_number", DataType::Int64, true), + Field::new("ws_quantity", DataType::Int64, true), + Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_list_price", DataType::Decimal128(7, 2), true), + 
Field::new("ws_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +// Helper to create a literal +fn lit_i64(val: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::Int64(Some(val)))) +} + +fn lit_i32(val: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(val)))) +} + +fn lit_bool(val: bool) -> Arc { + Arc::new(Literal::new(ScalarValue::Boolean(Some(val)))) +} + +// Helper to create binary expressions +fn and( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::And, right)) +} + +fn gte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::GtEq, right)) +} + +fn lte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::LtEq, right)) +} + +fn modulo( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Modulo, right)) +} + +fn eq( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Eq, right)) +} + +/// Build a predicate similar to TPC-DS q76 catalog_sales filter. +/// Uses placeholder columns instead of hash expressions. 
+pub fn catalog_sales_predicate(num_partitions: usize) -> Arc { + let cs_sold_date_sk: Arc = + Arc::new(Column::new("cs_sold_date_sk", 0)); + let cs_ship_addr_sk: Arc = + Arc::new(Column::new("cs_ship_addr_sk", 10)); + let cs_item_sk: Arc = Arc::new(Column::new("cs_item_sk", 15)); + + // Use a simple modulo expression as placeholder for hash + let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // cs_ship_addr_sk IS NULL + let is_null_expr: Arc = Arc::new(IsNullExpr::new(cs_ship_addr_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_item_sk.clone(), lit_i64(partition as i64)), + lte(cs_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(cs_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + // Final: is_null AND item_case AND date_case + and(and(is_null_expr, item_case_expr), date_case_expr) +} +/// Build a predicate similar to TPC-DS q76 web_sales filter. +/// Uses placeholder columns instead of hash expressions. 
+fn web_sales_predicate(num_partitions: usize) -> Arc { + let ws_sold_date_sk: Arc = + Arc::new(Column::new("ws_sold_date_sk", 0)); + let ws_item_sk: Arc = Arc::new(Column::new("ws_item_sk", 3)); + let ws_ship_customer_sk: Arc = + Arc::new(Column::new("ws_ship_customer_sk", 8)); + + // Use simple modulo expression as placeholder for hash + let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // ws_ship_customer_sk IS NULL + let is_null_expr: Arc = + Arc::new(IsNullExpr::new(ws_ship_customer_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_item_sk.clone(), lit_i64(partition as i64)), + lte(ws_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(ws_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + and(and(is_null_expr, item_case_expr), date_case_expr) +} + +/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression. 
+fn bench_simplify( + c: &mut Criterion, + name: &str, + schema: &Schema, + expr: &Arc<dyn PhysicalExpr>, +) { + let simplifier = PhysicalExprSimplifier::new(schema); + c.bench_function(name, |b| { + b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap())) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let cs_schema = catalog_sales_schema(); + let ws_schema = web_sales_schema(); + + for num_partitions in [16, 128] { + bench_simplify( + c, + &format!("tpc-ds/q76/cs/{num_partitions}"), + &cs_schema, + &catalog_sales_predicate(num_partitions), + ); + bench_simplify( + c, + &format!("tpc-ds/q76/ws/{num_partitions}"), + &ws_schema, + &web_sales_predicate(num_partitions), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index d734c86726f1d..11a60afc90a10 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -178,7 +178,7 @@ pub fn analyze( "ExprBoundaries has a non-zero distinct count although it represents an empty table" ); assert_or_internal_err!( - context.selectivity == Some(0.0), + context.selectivity.unwrap_or(0.0) == 0.0, "AnalysisContext has a non-zero selectivity although it represents an empty table" ); Ok(context) diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 70f97139f8af4..996bc4b08fcd2 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -1277,7 +1277,7 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New
schema: {}", self.schema, schema ); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 758317d3d2798..dac208be534cd 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; use std::fmt::{Debug, Formatter}; @@ -64,7 +65,7 @@ enum EvalMethod { /// for expressions that are infallible and can be cheaply computed for the entire /// record batch rather than just for the rows where the predicate is true. /// - /// CASE WHEN condition THEN column [ELSE NULL] END + /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END InfallibleExprOrNull, /// This is a specialization for a specific use case where we can take a fast path /// if there is just one when/then pair and both the `then` and `else` expressions @@ -72,9 +73,13 @@ enum EvalMethod { /// CASE WHEN condition THEN literal ELSE literal END ScalarOrScalar, /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions + /// if there is just one when/then pair, the `then` is an expression, and `else` is either + /// an expression, literal NULL or absent. /// - /// CASE WHEN condition THEN expression ELSE expression END + /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible + /// `then` expressions. 
+ /// + /// CASE WHEN condition THEN expression [ELSE expression] END ExpressionOrExpression(ProjectedCaseBody), /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals @@ -659,7 +664,7 @@ impl CaseExpr { && body.else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar - } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { + } else if body.when_then_expr.len() == 1 { EvalMethod::ExpressionOrExpression(body.project()?) } else { EvalMethod::NoExpression(body.project()?) @@ -961,32 +966,40 @@ impl CaseBody { let then_batch = filter_record_batch(batch, &when_filter)?; let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; - let else_selection = not(&when_value)?; - let else_filter = create_filter(&else_selection, optimize_filter); - let else_batch = filter_record_batch(batch, &else_filter)?; - - // keep `else_expr`'s data type and return type consistent - let e = self.else_expr.as_ref().unwrap(); - let return_type = self.data_type(&batch.schema())?; - let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - - let else_value = else_expr.evaluate(&else_batch)?; - - Ok(ColumnarValue::Array(match (then_value, else_value) { - (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t, &e) - } - (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t.to_scalar()?, &e) - } - (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t, &e.to_scalar()?) + match &self.else_expr { + None => { + let then_array = then_value.to_array(when_value.true_count())?; + scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array) } - (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) 
+ Some(else_expr) => { + let else_selection = not(&when_value)?; + let else_filter = create_filter(&else_selection, optimize_filter); + let else_batch = filter_record_batch(batch, &else_filter)?; + + // keep `else_expr`'s data type and return type consistent + let return_type = self.data_type(&batch.schema())?; + let else_expr = + try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(else_expr)); + + let else_value = else_expr.evaluate(&else_batch)?; + + Ok(ColumnarValue::Array(match (then_value, else_value) { + (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t, &e) + } + (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t.to_scalar()?, &e) + } + (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t, &e.to_scalar()?) + } + (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) + } + }?)) } - }?)) + } } } @@ -1137,7 +1150,15 @@ impl CaseExpr { self.body.when_then_expr[0].1.evaluate(batch) } else if true_count == 0 { // All input rows are false/null, just call the 'else' expression - self.body.else_expr.as_ref().unwrap().evaluate(batch) + match &self.body.else_expr { + Some(else_expr) => else_expr.evaluate(batch), + None => { + let return_type = self.data_type(&batch.schema())?; + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &return_type, + )?)) + } + } } else if projected.projection.len() < batch.num_columns() { // The case expressions do not use all the columns of the input batch. // Project first to reduce time spent filtering. 
diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index bd5c63a69979f..6fced231f3e6f 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -26,6 +26,7 @@ use arrow::compute::{CastOptions, can_cast_types}; use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::nested_struct::validate_struct_compatibility; use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -41,6 +42,22 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; +/// Check if struct-to-struct casting is allowed by validating field compatibility. +/// +/// This function applies the same validation rules as execution time to ensure +/// planning-time validation matches runtime validation, enabling fail-fast behavior +/// instead of deferring errors to execution. +fn can_cast_struct_types(source: &DataType, target: &DataType) -> bool { + match (source, target) { + (Struct(source_fields), Struct(target_fields)) => { + // Apply the same struct compatibility rules as at execution time. + // This ensures planning-time validation matches execution-time validation. 
+ validate_struct_compatibility(source_fields, target_fields).is_ok() + } + _ => false, + } +} + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone, Eq)] pub struct CastExpr { @@ -129,7 +146,7 @@ impl CastExpr { impl fmt::Display for CastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "CAST({} AS {})", self.expr, self.cast_type) } } @@ -237,6 +254,12 @@ pub fn cast_with_options( Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) + } else if can_cast_struct_types(&expr_type, &cast_type) { + // Allow struct-to-struct casts that pass name-based compatibility validation. + // This validation is applied at planning time (now) to fail fast, rather than + // deferring errors to execution time. The name-based casting logic will be + // executed at runtime via ColumnarValue::cast_to. 
+ Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") } @@ -289,10 +312,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -341,10 +361,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index 3dc0293da83d4..d80b6f4a588a4 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -114,7 +114,7 @@ impl Display for CastColumnExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "CAST_COLUMN({} AS {:?})", + "CAST_COLUMN({} AS {})", self.expr, self.target_field.data_type() ) diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 8c7e8c319fff4..cf844790a002e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -30,6 +30,7 @@ use arrow::{ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::ColumnarValue; +use datafusion_expr_common::placement::ExpressionPlacement; /// Represents the 
column at a given index in a RecordBatch /// @@ -146,6 +147,10 @@ impl PhysicalExpr for Column { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Column + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 643745ac0f07e..d285f8b377eca 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -51,6 +51,10 @@ impl FilterState { /// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also /// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where /// the same `ExecutionPlan` is reused with different data. +/// +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters #[derive(Debug)] pub struct DynamicFilterPhysicalExpr { /// The original children of this PhysicalExpr, if any. @@ -272,6 +276,10 @@ impl DynamicFilterPhysicalExpr { /// /// This method will return when [`Self::update`] is called and the generation increases. /// It does not guarantee that the filter is complete. + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. 
+ /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. pub async fn wait_update(&self) { let mut rx = self.state_watch.subscribe(); // Get the current generation @@ -283,17 +291,16 @@ impl DynamicFilterPhysicalExpr { /// Wait asynchronously until this dynamic filter is marked as complete. /// - /// This method returns immediately if the filter is already complete or if the filter - /// is not being used by any consumers. + /// This method returns immediately if the filter is already complete. /// Otherwise, it waits until [`Self::mark_complete`] is called. /// /// Unlike [`Self::wait_update`], this method guarantees that when it returns, /// the filter is fully complete with no more updates expected. - pub async fn wait_complete(self: &Arc) { - if !self.is_used() { - return; - } - + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. 
+ pub async fn wait_complete(&self) { if self.inner.read().is_complete { return; } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 1f3fefc60b7ad..9105297c96d61 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -33,6 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value @@ -134,6 +135,10 @@ impl PhysicalExpr for Literal { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Literal + } } /// Create a literal expression diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 0c9476bebaaf0..c727c8fa5f77e 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -37,7 +37,7 @@ use datafusion_expr::statistics::Distribution::{ }; use datafusion_expr::{ ColumnarValue, - type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, + type_coercion::{is_interval, is_signed_numeric, is_timestamp}, }; /// Negative expression @@ -190,7 +190,7 @@ pub fn negative( input_schema: &Schema, ) -> Result> { let data_type = arg.data_type(input_schema)?; - if is_null(&data_type) { + if data_type.is_null() { Ok(arg) } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index c9ace3239c645..c63550f430be7 100644 --- 
a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -72,7 +72,7 @@ impl TryCastExpr { impl fmt::Display for TryCastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TRY_CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "TRY_CAST({} AS {})", self.expr, self.cast_type) } } @@ -180,7 +180,7 @@ mod tests { // verify that its display is correct assert_eq!( - format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); @@ -231,7 +231,7 @@ mod tests { // verify that its display is correct assert_eq!( - format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 988e14c28e17c..bedd348dab92f 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] // Backward compatibility pub mod aggregate; diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 540fd620c92ce..dbbd289415277 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -29,7 +29,8 @@ use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - Result, ScalarValue, assert_or_internal_err, internal_datafusion_err, plan_err, + Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, + plan_err, }; use datafusion_physical_expr_common::metrics::ExecutionPlanMetricsSet; @@ 
-125,7 +126,8 @@ impl From for (Arc, String) { /// indices. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProjectionExprs { - exprs: Vec, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + exprs: Arc<[ProjectionExpr]>, } impl std::fmt::Display for ProjectionExprs { @@ -137,14 +139,16 @@ impl std::fmt::Display for ProjectionExprs { impl From> for ProjectionExprs { fn from(value: Vec) -> Self { - Self { exprs: value } + Self { + exprs: value.into(), + } } } impl From<&[ProjectionExpr]> for ProjectionExprs { fn from(value: &[ProjectionExpr]) -> Self { Self { - exprs: value.to_vec(), + exprs: value.iter().cloned().collect(), } } } @@ -152,7 +156,7 @@ impl From<&[ProjectionExpr]> for ProjectionExprs { impl FromIterator for ProjectionExprs { fn from_iter>(exprs: T) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into_iter().collect(), } } } @@ -164,12 +168,17 @@ impl AsRef<[ProjectionExpr]> for ProjectionExprs { } impl ProjectionExprs { - pub fn new(exprs: I) -> Self - where - I: IntoIterator, - { + /// Make a new [`ProjectionExprs`] from expressions iterator. + pub fn new(exprs: impl IntoIterator) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into_iter().collect(), + } + } + + /// Make a new [`ProjectionExprs`] from expressions. + pub fn from_expressions(exprs: impl Into>) -> Self { + Self { + exprs: exprs.into(), } } @@ -285,13 +294,14 @@ impl ProjectionExprs { { let exprs = self .exprs - .into_iter() + .iter() + .cloned() .map(|mut proj| { proj.expr = f(proj.expr)?; Ok(proj) }) - .collect::>>()?; - Ok(Self::new(exprs)) + .collect::>>()?; + Ok(Self::from_expressions(exprs)) } /// Apply another projection on top of this projection, returning the combined projection. @@ -361,17 +371,9 @@ impl ProjectionExprs { /// applied on top of this projection. 
pub fn try_merge(&self, other: &ProjectionExprs) -> Result { let mut new_exprs = Vec::with_capacity(other.exprs.len()); - for proj_expr in &other.exprs { - let new_expr = update_expr(&proj_expr.expr, &self.exprs, true)? - .ok_or_else(|| { - internal_datafusion_err!( - "Failed to combine projections: expression {} could not be applied on top of existing projections {}", - proj_expr.expr, - self.exprs.iter().map(|e| format!("{e}")).join(", ") - ) - })?; + for proj_expr in other.exprs.iter() { new_exprs.push(ProjectionExpr { - expr: new_expr, + expr: self.unproject_expr(&proj_expr.expr)?, alias: proj_expr.alias.clone(), }); } @@ -440,9 +442,16 @@ impl ProjectionExprs { } /// Project a schema according to this projection. - /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, - /// if the input schema is `[a: Int32, b: Int32, c: Int32]`, the output schema would be `[x: Int32, y: Int32]`. - /// Fields' metadata are preserved from the input schema. + /// + /// For example, given a projection: + /// * `SELECT a AS x, b + 1 AS y` + /// * where `a` is at index 0 + /// * `b` is at index 1 + /// + /// If the input schema is `[a: Int32, b: Int32, c: Int32]`, the output + /// schema would be `[x: Int32, y: Int32]`. + /// + /// Note that [`Field`] metadata are preserved from the input schema. pub fn project_schema(&self, input_schema: &Schema) -> Result { let fields: Result> = self .exprs @@ -471,6 +480,48 @@ impl ProjectionExprs { )) } + /// "unproject" an expression by applying this projection in reverse, + /// returning a new set of expressions that reference the original input + /// columns. 
+ /// + /// For example, consider + /// * an expression `c1_c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// This method would rewrite the expression to `c1 + c2 > 5` + pub fn unproject_expr( + &self, + expr: &Arc, + ) -> Result> { + update_expr(expr, &self.exprs, true)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to unproject an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + + /// "project" an expression using these projection's expressions + /// + /// For example, consider + /// * an expression `c1 + c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// * This method would rewrite the expression to `c1_c2 > 5` + pub fn project_expr( + &self, + expr: &Arc, + ) -> Result> { + update_expr(expr, &self.exprs, false)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to project an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + /// Create a new [`Projector`] from this projection and an input schema. /// /// A [`Projector`] can be used to apply this projection to record batches. @@ -602,12 +653,12 @@ impl ProjectionExprs { /// ``` pub fn project_statistics( &self, - mut stats: datafusion_common::Statistics, + mut stats: Statistics, output_schema: &Schema, - ) -> Result { + ) -> Result { let mut column_statistics = vec![]; - for proj_expr in &self.exprs { + for proj_expr in self.exprs.iter() { let expr = &proj_expr.expr; let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { std::mem::take(&mut stats.column_statistics[col.index()]) @@ -754,35 +805,92 @@ impl Projector { } } -impl IntoIterator for ProjectionExprs { - type Item = ProjectionExpr; - type IntoIter = std::vec::IntoIter; +/// Describes an immutable reference counted projection. +/// +/// This structure represents projecting a set of columns by index. 
+/// [`Arc`] is used to make it cheap to clone. +pub type ProjectionRef = Arc<[usize]>; - fn into_iter(self) -> Self::IntoIter { - self.exprs.into_iter() - } +/// Combine two projections. +/// +/// If `p1` is [`None`] then there are no changes. +/// Otherwise, if passed `p2` is not [`None`] then it is remapped +/// according to the `p1`. Otherwise, there are no changes. +/// +/// # Example +/// +/// If stored projection is [0, 2] and we call `apply_projection([0, 2, 3])`, +/// then the resulting projection will be [0, 3]. +/// +/// # Error +/// +/// Returns an internal error if `p1` contains index that is greater than `p2` len. +/// +pub fn combine_projections( + p1: Option<&ProjectionRef>, + p2: Option<&ProjectionRef>, +) -> Result> { + let Some(p1) = p1 else { + return Ok(None); + }; + let Some(p2) = p2 else { + return Ok(Some(Arc::clone(p1))); + }; + + Ok(Some( + p1.iter() + .map(|i| { + let idx = *i; + assert_or_internal_err!( + idx < p2.len(), + "unable to apply projection: index {} is greater than new projection len {}", + idx, + p2.len(), + ); + Ok(p2[*i]) + }) + .collect::>>()?, + )) } -/// The function operates in two modes: +/// The function projects / unprojects an expression with respect to set of +/// projection expressions. +/// +/// See also [`ProjectionExprs::unproject_expr`] and [`ProjectionExprs::project_expr`] +/// +/// 1) When `unproject` is `true`: +/// +/// Rewrites an expression with respect to the projection expressions, +/// effectively "unprojecting" it to reference the original input columns. +/// +/// For example, given +/// * the expressions `a@1 + b@2` and `c@0` +/// * and projection expressions `c@2, a@0, b@1` +/// +/// Then +/// * `a@1 + b@2` becomes `a@0 + b@1` +/// * `c@0` becomes `c@2` +/// +/// 2) When `unproject` is `false`: /// -/// 1) When `sync_with_child` is `true`: +/// Rewrites the expression to reference the projected expressions, +/// effectively "projecting" it. 
The resulting expression will reference the +/// indices as they appear in the projection. /// -/// The function updates the indices of `expr` if the expression resides -/// in the input plan. For instance, given the expressions `a@1 + b@2` -/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are -/// updated to `a@0 + b@1` and `c@2`. +/// If the expression cannot be rewritten after the projection, it returns +/// `None`. /// -/// 2) When `sync_with_child` is `false`: +/// For example, given +/// * the expressions `c@0`, `a@1` and `b@2` +/// * the projection `a@1 as a, c@0 as c_new`, /// -/// The function determines how the expression would be updated if a projection -/// was placed before the plan associated with the expression. If the expression -/// cannot be rewritten after the projection, it returns `None`. For example, -/// given the expressions `c@0`, `a@1` and `b@2`, and the projection with -/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes -/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +/// Then +/// * `c@0` becomes `c_new@1` +/// * `a@1` becomes `a@0` +/// * `b@2` results in `None` since the projection does not include `b`. /// /// # Errors -/// This function returns an error if `sync_with_child` is `true` and if any expression references +/// This function returns an error if `unproject` is `true` and if any expression references /// an index that is out of bounds for `projected_exprs`. 
/// For example: /// @@ -793,7 +901,7 @@ impl IntoIterator for ProjectionExprs { pub fn update_expr( expr: &Arc, projected_exprs: &[ProjectionExpr], - sync_with_child: bool, + unproject: bool, ) -> Result>> { #[derive(Debug, PartialEq)] enum RewriteState { @@ -817,7 +925,7 @@ pub fn update_expr( let Some(column) = expr.as_any().downcast_ref::() else { return Ok(Transformed::no(expr)); }; - if sync_with_child { + if unproject { state = RewriteState::RewrittenValid; // Update the index of `column`: let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index aa090743ad441..dab4153fa6828 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,8 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ - ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility, - expr_vec_fmt, + ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, expr_vec_fmt, }; /// Physical expression of a scalar function @@ -362,6 +362,12 @@ impl PhysicalExpr for ScalarFunctionExpr { fn is_volatile_node(&self) -> bool { self.fun.signature().volatility == Volatility::Volatile } + + fn placement(&self) -> ExpressionPlacement { + let arg_placements: Vec<_> = + self.args.iter().map(|arg| arg.placement()).collect(); + self.fun.placement(&arg_placements) + } } #[cfg(test)] diff --git a/datafusion/physical-expr/src/simplifier/const_evaluator.rs b/datafusion/physical-expr/src/simplifier/const_evaluator.rs index 65111b2911654..1e62e47ce2066 100644 --- a/datafusion/physical-expr/src/simplifier/const_evaluator.rs +++ b/datafusion/physical-expr/src/simplifier/const_evaluator.rs @@ -25,7 +25,6 @@ use 
arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::is_volatile; use crate::PhysicalExpr; use crate::expressions::{Column, Literal}; @@ -41,17 +40,22 @@ use crate::expressions::{Column, Literal}; /// - `(1 + 2) * 3` -> `9` (with bottom-up traversal) /// - `'hello' || ' world'` -> `'hello world'` pub fn simplify_const_expr( - expr: &Arc, + expr: Arc, ) -> Result>> { - if is_volatile(expr) || has_column_references(expr) { - return Ok(Transformed::no(Arc::clone(expr))); - } + simplify_const_expr_with_dummy(expr, &create_dummy_batch()?) +} - // Create a 1-row dummy batch for evaluation - let batch = create_dummy_batch()?; +pub(crate) fn simplify_const_expr_with_dummy( + expr: Arc, + batch: &RecordBatch, +) -> Result>> { + // If expr is already a const literal or can't be evaluated into one. 
+ if expr.as_any().is::() || (!can_evaluate_as_constant(&expr)) { + return Ok(Transformed::no(expr)); + } // Evaluate the expression - match expr.evaluate(&batch) { + match expr.evaluate(batch) { Ok(ColumnarValue::Scalar(scalar)) => { Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) } @@ -62,17 +66,33 @@ pub fn simplify_const_expr( } Ok(_) => { // Unexpected result - keep original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } Err(_) => { // On error, keep original expression // The expression might succeed at runtime due to short-circuit evaluation // or other runtime conditions - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } } } +fn can_evaluate_as_constant(expr: &Arc) -> bool { + let mut can_evaluate = true; + + expr.apply(|e| { + if e.as_any().is::() || e.is_volatile_node() { + can_evaluate = false; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("apply should not fail"); + + can_evaluate +} + /// Create a 1-row dummy RecordBatch for evaluating constant expressions. /// /// The batch is never actually accessed for data - it's just needed because @@ -80,7 +100,7 @@ pub fn simplify_const_expr( /// that only contain literals, the batch content is irrelevant. /// /// This is the same approach used in the logical expression `ConstEvaluator`. 
-fn create_dummy_batch() -> Result { +pub(crate) fn create_dummy_batch() -> Result { // RecordBatch requires at least one column let dummy_schema = Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)])); let col = new_null_array(&DataType::Null, 1); diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 97395f4fe8a27..45ead82a0a93d 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -21,7 +21,14 @@ use arrow::datatypes::Schema; use datafusion_common::{Result, tree_node::TreeNode}; use std::sync::Arc; -use crate::{PhysicalExpr, simplifier::not::simplify_not_expr}; +use crate::{ + PhysicalExpr, + simplifier::{ + const_evaluator::{create_dummy_batch, simplify_const_expr_with_dummy}, + not::simplify_not_expr, + unwrap_cast::unwrap_cast_in_comparison, + }, +}; pub mod const_evaluator; pub mod not; @@ -50,21 +57,23 @@ impl<'a> PhysicalExprSimplifier<'a> { let mut count = 0; let schema = self.schema; + let batch = create_dummy_batch()?; + while count < MAX_LOOP_COUNT { count += 1; let result = current_expr.transform(|node| { - #[cfg(test)] + #[cfg(debug_assertions)] let original_type = node.data_type(schema).unwrap(); // Apply NOT expression simplification first, then unwrap cast optimization, // then constant expression evaluation - let rewritten = simplify_not_expr(&node, schema)? + let rewritten = simplify_not_expr(node, schema)? + .transform_data(|node| unwrap_cast_in_comparison(node, schema))? .transform_data(|node| { - unwrap_cast::unwrap_cast_in_comparison(node, schema) - })? 
- .transform_data(|node| const_evaluator::simplify_const_expr(&node))?; + simplify_const_expr_with_dummy(node, &batch) + })?; - #[cfg(test)] + #[cfg(debug_assertions)] assert_eq!( rewritten.data.data_type(schema).unwrap(), original_type, diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 9b65d5cba95a5..ea5467d0a4b42 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -44,13 +44,13 @@ use crate::expressions::{BinaryExpr, InListExpr, Literal, NotExpr, in_list, lit} /// TreeNodeRewriter, multiple passes will automatically be applied until no more /// transformations are possible. pub fn simplify_not_expr( - expr: &Arc, + expr: Arc, schema: &Schema, ) -> Result>> { // Check if this is a NOT expression let not_expr = match expr.as_any().downcast_ref::() { Some(not_expr) => not_expr, - None => return Ok(Transformed::no(Arc::clone(expr))), + None => return Ok(Transformed::no(expr)), }; let inner_expr = not_expr.arg(); @@ -120,5 +120,5 @@ pub fn simplify_not_expr( } // If no simplification possible, return the original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index ae6da9c5e0dc5..0de517cd36c87 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,10 +34,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - Result, ScalarValue, - tree_node::{Transformed, TreeNode}, -}; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; @@ -49,14 +46,12 @@ pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { - if 
let Some(binary) = e.as_any().downcast_ref::() - && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? - { - return Ok(Transformed::yes(unwrapped)); - } - Ok(Transformed::no(e)) - }) + if let Some(binary) = expr.as_any().downcast_ref::() + && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? + { + return Ok(Transformed::yes(unwrapped)); + } + Ok(Transformed::no(expr)) } /// Try to unwrap casts in binary expressions @@ -144,7 +139,7 @@ mod tests { use super::*; use crate::expressions::{col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{ScalarValue, tree_node::TreeNode}; use datafusion_expr::Operator; /// Check if an expression is a cast expression @@ -484,8 +479,10 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); @@ -602,8 +599,10 @@ mod tests { // Create AND expression let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index cf3c15509c29a..5caee8b047d83 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -20,7 +20,7 @@ use datafusion_common::Result; use 
datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateInputMode}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; @@ -116,13 +116,13 @@ impl PhysicalOptimizerRule for AggregateStatistics { /// the `ExecutionPlan.children()` method that returns an owned reference. fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { if let Some(final_agg_exec) = node.as_any().downcast_ref::() - && !final_agg_exec.mode().is_first_stage() + && final_agg_exec.mode().input_mode() == AggregateInputMode::Partial && final_agg_exec.group_expr().is_empty() { let mut child = Arc::clone(final_agg_exec.input()); loop { if let Some(partial_agg_exec) = child.as_any().downcast_ref::() - && partial_agg_exec.mode().is_first_stage() + && partial_agg_exec.mode().input_mode() == AggregateInputMode::Raw && partial_agg_exec.group_expr().is_empty() && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) { diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 782e0754b7d27..6d8e7995c18c2 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { Arc::clone(input_agg_exec.input()), input_agg_exec.input_schema(), ) - .map(|combined_agg| combined_agg.with_limit(agg_exec.limit())) + .map(|combined_agg| { + combined_agg.with_limit_options(agg_exec.limit_options()) + }) .ok() .map(Arc::new) } else { diff --git 
a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 6120e1f3b5826..790669b5c9dbf 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -36,7 +36,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::logical_plan::{Aggregate, JoinType}; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ @@ -49,7 +49,7 @@ use datafusion_physical_plan::aggregates::{ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, + CrossJoinExec, HashJoinExec, HashJoinExecBuilder, PartitionMode, SortMergeJoinExec, }; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -295,6 +295,7 @@ pub fn adjust_input_keys_ordering( projection, mode, null_equality, + null_aware, .. }) = plan.as_any().downcast_ref::() { @@ -304,17 +305,19 @@ pub fn adjust_input_keys_ordering( Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec, )| { - HashJoinExec::try_new( + HashJoinExecBuilder::new( Arc::clone(left), Arc::clone(right), new_conditions.0, - filter.clone(), - join_type, - // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. 
- projection.clone(), - PartitionMode::Partitioned, - *null_equality, + *join_type, ) + .with_filter(filter.clone()) + // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. + .with_projection_ref(projection.clone()) + .with_partition_mode(PartitionMode::Partitioned) + .with_null_equality(*null_equality) + .with_null_aware(*null_aware) + .build() .map(|e| Arc::new(e) as _) }; return reorder_partitioned_join_keys( @@ -618,6 +621,7 @@ pub fn reorder_join_keys_to_inputs( projection, mode, null_equality, + null_aware, .. }) = plan_any.downcast_ref::() { @@ -635,16 +639,20 @@ pub fn reorder_join_keys_to_inputs( right_keys, } = join_keys; let new_join_on = new_join_conditions(&left_keys, &right_keys); - return Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - new_join_on, - filter.clone(), - join_type, - projection.clone(), - PartitionMode::Partitioned, - *null_equality, - )?)); + return Ok(Arc::new( + HashJoinExecBuilder::new( + Arc::clone(left), + Arc::clone(right), + new_join_on, + *join_type, + ) + .with_filter(filter.clone()) + .with_projection_ref(projection.clone()) + .with_partition_mode(PartitionMode::Partitioned) + .with_null_equality(*null_equality) + .with_null_aware(*null_aware) + .build()?, + )); } } } else if let Some(SortMergeJoinExec { @@ -1297,10 +1305,25 @@ pub fn ensure_distribution( // Allow subset satisfaction when: // 1. Current partition count >= threshold // 2. Not a partitioned join since must use exact hash matching for joins + // 3. Not a grouping set aggregate (requires exact hash including __grouping_id) let current_partitions = child.plan.output_partitioning().partition_count(); + + // Check if the hash partitioning requirement includes __grouping_id column. 
+ // Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash + // partitioning on all group columns including __grouping_id to ensure partial + // aggregates from different partitions are correctly combined. + let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs) + if exprs.iter().any(|expr| { + expr.as_any() + .downcast_ref::() + .is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID) + }) + ); + let allow_subset_satisfy_partitioning = current_partitions >= subset_satisfaction_threshold - && !is_partitioned_join; + && !is_partitioned_join + && !requires_grouping_id; // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index a5fafb9e87e1d..247ebb2785dd3 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -581,11 +581,17 @@ fn analyze_immediate_sort_removal( // Remove the sort: node.children = node.children.swap_remove(0).children; if let Some(fetch) = sort_exec.fetch() { + let required_ordering = sort_exec.properties().output_ordering().cloned(); // If the sort has a fetch, we need to add a limit: if properties.output_partitioning().partition_count() == 1 { - Arc::new(GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch))) + let mut global_limit = + GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch)); + global_limit.set_required_ordering(required_ordering); + Arc::new(global_limit) } else { - Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) + let mut local_limit = LocalLimitExec::new(Arc::clone(sort_input), fetch); + local_limit.set_required_ordering(required_ordering); + Arc::new(local_limit) } } else { Arc::clone(sort_input) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs 
b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 698fdea8e766e..2d9bfe217f40e 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -35,6 +35,7 @@ use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ @@ -353,6 +354,8 @@ fn pushdown_requirement_to_children( Ok(None) } } + } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::() { + handle_aggregate_pushdown(aggregate_exec, parent_required) } else if maintains_input_order.is_empty() || !maintains_input_order.iter().any(|o| *o) || plan.as_any().is::() @@ -388,6 +391,77 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down } +/// Try to push sorting through [`AggregateExec`] +/// +/// `AggregateExec` only preserves the input order of its group by columns +/// (not aggregates in general, which are formed from arbitrary expressions over +/// input) +/// +/// Thus function rewrites the parent required ordering in terms of the +/// aggregate input if possible. This rewritten requirement represents the +/// ordering of the `AggregateExec`'s **input** that would also satisfy the +/// **parent** ordering. +/// +/// If no such mapping is possible (e.g. because the sort references aggregate +/// columns), returns None. 
+fn handle_aggregate_pushdown( + aggregate_exec: &AggregateExec, + parent_required: OrderingRequirements, +) -> Result>>> { + if !aggregate_exec + .maintains_input_order() + .into_iter() + .any(|o| o) + { + return Ok(None); + } + + let group_expr = aggregate_exec.group_expr(); + // GROUPING SETS introduce additional output columns and NULL substitutions; + // skip pushdown until we can map those cases safely. + if group_expr.has_grouping_set() { + return Ok(None); + } + + let group_input_exprs = group_expr.input_exprs(); + let parent_requirement = parent_required.into_single(); + let mut child_requirement = Vec::with_capacity(parent_requirement.len()); + + for req in parent_requirement { + // Sort above AggregateExec should reference its output columns. Map each + // output group-by column to its original input expression. + let Some(column) = req.expr.as_any().downcast_ref::() else { + return Ok(None); + }; + if column.index() >= group_input_exprs.len() { + // AggregateExec does not produce output that is sorted on aggregate + // columns so those can not be pushed through. + return Ok(None); + } + child_requirement.push(PhysicalSortRequirement::new( + Arc::clone(&group_input_exprs[column.index()]), + req.options, + )); + } + + let Some(child_requirement) = LexRequirement::new(child_requirement) else { + return Ok(None); + }; + + // Keep sort above aggregate unless input ordering already satisfies the + // mapped requirement. + if aggregate_exec + .input() + .equivalence_properties() + .ordering_satisfy_requirement(child_requirement.iter().cloned())? 
+ { + let child_requirements = OrderingRequirements::new(child_requirement); + Ok(Some(vec![Some(child_requirements)])) + } else { + Ok(None) + } +} + /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( @@ -723,7 +797,7 @@ fn handle_hash_join( .collect(); let column_indices = build_join_column_index(plan); - let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + let projected_indices: Vec<_> = if let Some(projection) = plan.projection.as_ref() { projection.iter().map(|&i| &column_indices[i]).collect() } else { column_indices.iter().collect() diff --git a/datafusion/physical-optimizer/src/ensure_coop.rs b/datafusion/physical-optimizer/src/ensure_coop.rs index dfa97fc840333..5d00d00bce21d 100644 --- a/datafusion/physical-optimizer/src/ensure_coop.rs +++ b/datafusion/physical-optimizer/src/ensure_coop.rs @@ -27,7 +27,7 @@ use crate::PhysicalOptimizerRule; use datafusion_common::Result; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; @@ -67,23 +67,57 @@ impl PhysicalOptimizerRule for EnsureCooperative { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_up(|plan| { - let is_leaf = plan.children().is_empty(); - let is_exchange = plan.properties().evaluation_type == EvaluationType::Eager; - if (is_leaf || is_exchange) - && plan.properties().scheduling_type != SchedulingType::Cooperative - { - // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to - // ensure the plans they participate in are properly cooperative. 
- Ok(Transformed::new( - Arc::new(CooperativeExec::new(Arc::clone(&plan))), - true, - TreeNodeRecursion::Continue, - )) - } else { + use std::cell::RefCell; + + let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new()); + + plan.transform_down_up( + // Down phase: Push parent properties into the stack + |plan| { + let props = plan.properties(); + ancestry_stack + .borrow_mut() + .push((props.scheduling_type, props.evaluation_type)); Ok(Transformed::no(plan)) - } - }) + }, + // Up phase: Wrap nodes with CooperativeExec if needed + |plan| { + ancestry_stack.borrow_mut().pop(); + + let props = plan.properties(); + let is_cooperative = props.scheduling_type == SchedulingType::Cooperative; + let is_leaf = plan.children().is_empty(); + let is_exchange = props.evaluation_type == EvaluationType::Eager; + + let mut is_under_cooperative_context = false; + for (scheduling_type, evaluation_type) in + ancestry_stack.borrow().iter().rev() + { + // If nearest ancestor is cooperative, we are under a cooperative context + if *scheduling_type == SchedulingType::Cooperative { + is_under_cooperative_context = true; + break; + // If nearest ancestor is eager, the cooperative context will be reset + } else if *evaluation_type == EvaluationType::Eager { + is_under_cooperative_context = false; + break; + } + } + + // Wrap if: + // 1. Node is a leaf or exchange point + // 2. Node is not already cooperative + // 3. Not under any Cooperative context + if (is_leaf || is_exchange) + && !is_cooperative + && !is_under_cooperative_context + { + return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan)))); + } + + Ok(Transformed::no(plan)) + }, + ) .map(|t| t.data) } @@ -115,4 +149,264 @@ mod tests { DataSourceExec: partitions=1, partition_sizes=[1] "); } + + #[tokio::test] + async fn test_optimizer_is_idempotent() { + // Comprehensive idempotency test: verify f(f(...f(x))) = f(x) + // This test covers: + // 1. Multiple runs on unwrapped plan + // 2. 
Multiple runs on already-wrapped plan + // 3. No accumulation of CooperativeExec nodes + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Test 1: Start with unwrapped plan, run multiple times + let unwrapped_plan = scan_partitioned(1); + let mut current = unwrapped_plan; + let mut stable_result = String::new(); + + for run in 1..=5 { + current = rule.optimize(current, &config).unwrap(); + let display = displayable(current.as_ref()).indent(true).to_string(); + + if run == 1 { + stable_result = display.clone(); + assert_eq!(display.matches("CooperativeExec").count(), 1); + } else { + assert_eq!( + display, stable_result, + "Run {run} should match run 1 (idempotent)" + ); + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should always have exactly 1 CooperativeExec, not accumulate" + ); + } + } + + // Test 2: Start with already-wrapped plan, verify no double wrapping + let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1))); + let result = rule.optimize(pre_wrapped, &config).unwrap(); + let display = displayable(result.as_ref()).indent(true).to_string(); + + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should not double-wrap already cooperative plans" + ); + assert_eq!( + display, stable_result, + "Pre-wrapped plan should produce same result as unwrapped after optimization" + ); + } + + #[tokio::test] + async fn test_selective_wrapping() { + // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes + // Also verify depth tracking prevents double wrapping in subtrees + use datafusion_physical_expr::expressions::lit; + use datafusion_physical_plan::filter::FilterExec; + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Case 1: Filter -> Scan (middle node should not be wrapped) + let scan = scan_partitioned(1); + let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap()); + let optimized = rule.optimize(filter, 
&config).unwrap(); + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + assert_eq!(display.matches("CooperativeExec").count(), 1); + assert!(display.contains("FilterExec")); + + // Case 2: Filter -> CoopExec -> Scan (depth tracking prevents double wrap) + let scan2 = scan_partitioned(1); + let wrapped_scan = Arc::new(CooperativeExec::new(scan2)); + let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap()); + let optimized2 = rule.optimize(filter2, &config).unwrap(); + let display2 = displayable(optimized2.as_ref()).indent(true).to_string(); + + assert_eq!(display2.matches("CooperativeExec").count(), 1); + } + + #[tokio::test] + async fn test_multiple_leaf_nodes() { + // When there are multiple leaf nodes, each should be wrapped separately + use datafusion_physical_plan::union::UnionExec; + + let scan1 = scan_partitioned(1); + let scan2 = scan_partitioned(1); + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new() + .optimize(union as Arc, &config) + .unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Each leaf should have its own CooperativeExec + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Each leaf node should be wrapped separately" + ); + assert_eq!( + display.matches("DataSourceExec").count(), + 2, + "Both data sources should be present" + ); + } + + #[tokio::test] + async fn test_eager_evaluation_resets_cooperative_context() { + // Test that cooperative context is reset when encountering an eager evaluation boundary. 
+ use arrow::datatypes::Schema; + use datafusion_common::{Result, internal_err}; + use datafusion_execution::TaskContext; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, Partitioning, PlanProperties, + SendableRecordBatchStream, + execution_plan::{Boundedness, EmissionType}, + }; + use std::any::Any; + use std::fmt::Formatter; + + #[derive(Debug)] + struct DummyExec { + name: String, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + properties: PlanProperties, + } + + impl DummyExec { + fn new( + name: &str, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::new(Schema::empty())), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + .with_scheduling_type(scheduling_type) + .with_evaluation_type(evaluation_type); + + Self { + name: name.to_string(), + input, + scheduling_type, + evaluation_type, + properties, + } + } + } + + impl DisplayAs for DummyExec { + fn fmt_as( + &self, + _: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result { + write!(f, "{}", self.name) + } + } + + impl ExecutionPlan for DummyExec { + fn name(&self) -> &str { + &self.name + } + fn as_any(&self) -> &dyn Any { + self + } + fn properties(&self) -> &PlanProperties { + &self.properties + } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(DummyExec::new( + &self.name, + Arc::clone(&children[0]), + self.scheduling_type, + self.evaluation_type, + ))) + } + fn execute( + &self, + _: usize, + _: Arc, + ) -> Result { + internal_err!("DummyExec does not support execution") + } + } + + // Build a plan similar to the original test: + // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter -> exch2(Coop,Eager) -> filter + let 
scan = scan_partitioned(1); + let exch1 = Arc::new(DummyExec::new( + "exch1", + scan, + SchedulingType::NonCooperative, + EvaluationType::Eager, + )); + let coop = Arc::new(CooperativeExec::new(exch1)); + let filter1 = Arc::new(DummyExec::new( + "filter1", + coop, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + let exch2 = Arc::new(DummyExec::new( + "exch2", + filter1, + SchedulingType::Cooperative, + EvaluationType::Eager, + )); + let filter2 = Arc::new(DummyExec::new( + "filter2", + exch2, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Expected wrapping: + // - Scan (leaf) gets wrapped + // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper + // - filter1 is protected by exch2's cooperative context, no extra wrap + // - exch2 (already Cooperative) does NOT get wrapped + // - filter2 (not leaf or eager) does NOT get wrapped + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1" + ); + + assert_snapshot!(display, @r" + filter2 + exch2 + filter1 + CooperativeExec + exch1 + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } } diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index f837c79a4e391..02ef378d704a0 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -34,7 +34,7 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + CrossJoinExec, HashJoinExec, 
HashJoinExecBuilder, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -184,35 +184,28 @@ pub(crate) fn try_collect_left( match (left_can_collect, right_can_collect) { (true, true) => { + // Don't swap null-aware anti joins as they have specific side requirements if hash_join.join_type().supports_swap() + && !hash_join.null_aware && should_swap_join_order(&**left, &**right)? { Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) } else { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))) + Ok(Some(Arc::new( + HashJoinExecBuilder::from(hash_join) + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))) } } - (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))), + (true, false) => Ok(Some(Arc::new( + HashJoinExecBuilder::from(hash_join) + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))), (false, true) => { - if hash_join.join_type().supports_swap() { + // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() && !hash_join.null_aware { hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) } else { Ok(None) @@ -232,20 +225,28 @@ pub(crate) fn partitioned_hash_join( ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? 
+ // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() + && !hash_join.null_aware + && should_swap_join_order(&**left, &**right)? { hash_join.swap_inputs(PartitionMode::Partitioned) } else { - Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::Partitioned, - hash_join.null_equality(), - )?)) + // Null-aware anti joins must use CollectLeft mode because they track probe-side state + // (probe_side_non_empty, probe_side_has_null) per-partition, but need global knowledge + // for correct null handling. With partitioning, a partition might not see probe rows + // even if the probe side is globally non-empty, leading to incorrect NULL row handling. + let partition_mode = if hash_join.null_aware { + PartitionMode::CollectLeft + } else { + PartitionMode::Partitioned + }; + + Ok(Arc::new( + HashJoinExecBuilder::from(hash_join) + .with_partition_mode(partition_mode) + .build()?, + )) } } @@ -277,7 +278,9 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); + // Don't swap null-aware anti joins as they have specific side requirements if hash_join.join_type().supports_swap() + && !hash_join.null_aware && should_swap_join_order(&**left, &**right)? 
{ hash_join @@ -484,6 +487,7 @@ pub fn hash_join_swap_subrule( if let Some(hash_join) = input.as_any().downcast_ref::() && hash_join.left.boundedness().is_unbounded() && !hash_join.right.boundedness().is_unbounded() + && !hash_join.null_aware // Don't swap null-aware anti joins && matches!( *hash_join.join_type(), JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index e98772291cbeb..3a0d79ae2d234 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] pub mod aggregate_statistics; pub mod combine_partial_final_agg; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 4cb3abe30bae2..e7bede494da99 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -50,6 +50,7 @@ pub struct GlobalRequirements { fetch: Option, skip: usize, satisfied: bool, + preserve_order: bool, } impl LimitPushdown { @@ -69,6 +70,7 @@ impl PhysicalOptimizerRule for LimitPushdown { fetch: None, skip: 0, satisfied: false, + preserve_order: false, }; pushdown_limits(plan, global_state) } @@ -111,6 +113,13 @@ impl LimitExec { Self::Local(_) => 0, } } + + fn preserve_order(&self) -> bool { + match self { + Self::Global(global) => global.required_ordering().is_some(), + Self::Local(local) => local.required_ordering().is_some(), + } + } } impl From for Arc { @@ -145,6 +154,8 @@ pub fn pushdown_limit_helper( ); global_state.skip = skip; global_state.fetch = fetch; + global_state.preserve_order = limit_exec.preserve_order(); + global_state.satisfied = 
false; // Now the global state has the most recent information, we can remove // the `LimitExec` plan. We will decide later if we should add it again @@ -162,7 +173,7 @@ pub fn pushdown_limit_helper( // If we have a non-limit operator with fetch capability, update global // state as necessary: if pushdown_plan.fetch().is_some() { - if global_state.fetch.is_none() { + if global_state.skip == 0 { global_state.satisfied = true; } (global_state.skip, global_state.fetch) = combine_limit( @@ -241,17 +252,28 @@ pub fn pushdown_limit_helper( let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch); if global_state.satisfied { if let Some(plan_with_fetch) = maybe_fetchable { - Ok((Transformed::yes(plan_with_fetch), global_state)) + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + Ok((Transformed::yes(plan_with_preserve_order), global_state)) } else { Ok((Transformed::no(pushdown_plan), global_state)) } } else { global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + if global_skip > 0 { - add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) + add_global_limit( + plan_with_preserve_order, + global_skip, + Some(global_fetch), + ) } else { - plan_with_fetch + plan_with_preserve_order } } else { add_limit(pushdown_plan, global_skip, global_fetch) diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 671d247cf36a5..fe9636f67619b 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -20,7 +20,7 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; +use 
datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -63,7 +63,7 @@ impl LimitedDistinctAggregation { aggr.input_schema(), ) .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + .with_limit_options(Some(LimitOptions::new(limit))); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 0dc6a25fbc0b7..afc0ee1a336dd 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -244,10 +244,6 @@ impl ExecutionPlan for OutputRequirementExec { unreachable!(); } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input.partition_statistics(partition) } diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 281d61aecf538..44d0926a8b250 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -32,7 +32,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{JoinSide, JoinType, Result}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, is_volatile}; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; @@ -135,7 +135,7 @@ fn try_push_down_join_filter( ); let new_lhs_length = lhs_rewrite.data.0.schema().fields.len(); - let projections = match projections { + let projections = match projections.as_ref() { None => 
match join.join_type() { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { // Build projections that ignore the newly projected columns. @@ -349,8 +349,7 @@ impl<'a> JoinFilterRewriter<'a> { // Recurse if there is a dependency to both sides or if the entire expression is volatile. let depends_on_other_side = self.depends_on_join_side(&expr, self.join_side.negate())?; - let is_volatile = is_volatile_expression_tree(expr.as_ref()); - if depends_on_other_side || is_volatile { + if depends_on_other_side || is_volatile(&expr) { return expr.map_children(|expr| self.rewrite(expr)); } @@ -431,18 +430,6 @@ impl<'a> JoinFilterRewriter<'a> { } } -fn is_volatile_expression_tree(expr: &dyn PhysicalExpr) -> bool { - if expr.is_volatile_node() { - return true; - } - - expr.children() - .iter() - .map(|expr| is_volatile_expression_tree(expr.as_ref())) - .reduce(|lhs, rhs| lhs || rhs) - .unwrap_or(false) -} - #[cfg(test)] mod test { use super::*; diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 7b2983ee71996..cec6bd70a2089 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::aggregates::LimitOptions; use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; @@ -47,28 +48,47 @@ impl TopKAggregation { order_desc: bool, limit: usize, ) -> Option> { - // ensure the sort direction matches aggregate function - let (field, desc) = aggr.get_minmax_desc()?; - if desc != order_desc { - return None; - } - let 
group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; - let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - let vt = field.data_type(); - if !topk_types_supported(&kt, vt) { + // Current only support single group key + let (group_key, group_key_alias) = + aggr.group_expr().expr().iter().exactly_one().ok()?; + let kt = group_key.data_type(&aggr.input().schema()).ok()?; + let vt = if let Some((field, _)) = aggr.get_minmax_desc() { + field.data_type().clone() + } else { + kt.clone() + }; + if !topk_types_supported(&kt, &vt) { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { return None; } - // ensure the sort is on the same field as the aggregate output - if order_by != field.name() { + // Check if this is ordering by an aggregate function (MIN/MAX) + if let Some((field, desc)) = aggr.get_minmax_desc() { + // ensure the sort direction matches aggregate function + if desc != order_desc { + return None; + } + // ensure the sort is on the same field as the aggregate output + if order_by != field.name() { + return None; + } + } else if aggr.aggr_expr().is_empty() { + // This is a GROUP BY without aggregates, check if ordering is on the group key itself + if order_by != group_key_alias { + return None; + } + } else { + // Has aggregates but not MIN/MAX, or doesn't DISTINCT return None; } // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = aggr.with_new_limit(Some(limit)); + let new_aggr = AggregateExec::with_new_limit_options( + aggr, + Some(LimitOptions::new_with_order(limit, order_desc)), + ); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index c0aab4080da77..67127c2a238f9 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -25,7 +25,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, 
TreeNode}; use datafusion_common::{Result, plan_datafusion_err}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; -use datafusion_physical_plan::aggregates::{AggregateExec, concat_slices}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateInputMode, concat_slices, +}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -81,7 +83,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { // ordering fields may be pruned out by first stage aggregates. // Hence, necessary information for proper merge is added during // the first stage to the state field, which the final stage uses. - if !aggr_exec.mode().is_first_stage() { + if aggr_exec.mode().input_mode() == AggregateInputMode::Partial { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index c46cde8786eb4..2b8a2cfa68897 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -128,7 +128,9 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(g, h)| unsafe { + hash == h && self.values.get_unchecked(g).is_eq(key) + }, |&(_, h)| h, ); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4dd9482ac4322..27eee0025aa60 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -89,10 +89,54 @@ pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool const 
AGGREGATION_HASH_SEED: ahash::RandomState = ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); +/// Whether an aggregate stage consumes raw input data or intermediate +/// accumulator state from a previous aggregation stage. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AggregateInputMode { + /// The stage consumes raw, unaggregated input data and calls + /// [`Accumulator::update_batch`]. + Raw, + /// The stage consumes intermediate accumulator state from a previous + /// aggregation stage and calls [`Accumulator::merge_batch`]. + Partial, +} + +/// Whether an aggregate stage produces intermediate accumulator state +/// or final output values. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AggregateOutputMode { + /// The stage produces intermediate accumulator state, serialized via + /// [`Accumulator::state`]. + Partial, + /// The stage produces final output values via + /// [`Accumulator::evaluate`]. + Final, +} + /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase /// aggregation and how these modes are used. +/// +/// # Variants and their input/output modes +/// +/// Each variant can be characterized by its [`AggregateInputMode`] and +/// [`AggregateOutputMode`]: +/// +/// ```text +/// | Input: Raw data | Input: Partial state +/// Output: Final values | Single, SinglePartitioned | Final, FinalPartitioned +/// Output: Partial state | Partial | PartialReduce +/// ``` +/// +/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`] +/// to query these properties. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AggregateMode { /// One of multiple layers of aggregation, any input partitioning @@ -144,18 +188,56 @@ pub enum AggregateMode { /// This mode requires that the input has more than one partition, and is /// partitioned by group key (like FinalPartitioned). SinglePartitioned, + /// Combine multiple partial aggregations to produce a new partial + /// aggregation. + /// + /// Input is intermediate accumulator state (like Final), but output is + /// also intermediate accumulator state (like Partial). This enables + /// tree-reduce aggregation strategies where partial results from + /// multiple workers are combined in multiple stages before a final + /// evaluation. + /// + /// ```text + /// Final + /// / \ + /// PartialReduce PartialReduce + /// / \ / \ + /// Partial Partial Partial Partial + /// ``` + PartialReduce, } impl AggregateMode { - /// Checks whether this aggregation step describes a "first stage" calculation. - /// In other words, its input is not another aggregation result and the - /// `merge_batch` method will not be called for these modes. - pub fn is_first_stage(&self) -> bool { + /// Returns the [`AggregateInputMode`] for this mode: whether this + /// stage consumes raw input data or intermediate accumulator state. + /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. + pub fn input_mode(&self) -> AggregateInputMode { match self { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => true, - AggregateMode::Final | AggregateMode::FinalPartitioned => false, + | AggregateMode::SinglePartitioned => AggregateInputMode::Raw, + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::PartialReduce => AggregateInputMode::Partial, + } + } + + /// Returns the [`AggregateOutputMode`] for this mode: whether this + /// stage produces intermediate accumulator state or final output values. 
+ /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. + pub fn output_mode(&self) -> AggregateOutputMode { + match self { + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => AggregateOutputMode::Final, + AggregateMode::Partial | AggregateMode::PartialReduce => { + AggregateOutputMode::Partial + } } } } @@ -502,19 +584,58 @@ enum DynamicFilterAggregateType { Max, } +/// Configuration for limit-based optimizations in aggregation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimitOptions { + /// The maximum number of rows to return + pub limit: usize, + /// Optional ordering direction (true = descending, false = ascending) + /// This is used for TopK aggregation to maintain a priority queue with the correct ordering + pub descending: Option, +} + +impl LimitOptions { + /// Create a new LimitOptions with a limit and no specific ordering + pub fn new(limit: usize) -> Self { + Self { + limit, + descending: None, + } + } + + /// Create a new LimitOptions with a limit and ordering direction + pub fn new_with_order(limit: usize, descending: bool) -> Self { + Self { + limit, + descending: Some(descending), + } + } + + pub fn limit(&self) -> usize { + self.limit + } + + pub fn descending(&self) -> Option { + self.descending + } +} + /// Hash aggregate execution plan #[derive(Debug, Clone)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, /// Group by expressions - group_by: PhysicalGroupBy, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + group_by: Arc, /// Aggregate expressions - aggr_expr: Vec>, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. 
+ aggr_expr: Arc<[Arc]>, /// FILTER (WHERE clause) expression for each aggregate expression - filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause - limit: Option, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. + filter_expr: Arc<[Option>]>, + /// Configuration for limit-based optimizations + limit_options: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, /// Schema after the aggregate is applied @@ -546,19 +667,19 @@ impl AggregateExec { /// Rewrites aggregate exec with new aggregate expressions. pub fn with_new_aggr_exprs( &self, - aggr_expr: Vec>, + aggr_expr: impl Into]>>, ) -> Self { Self { - aggr_expr, + aggr_expr: aggr_expr.into(), // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), metrics: ExecutionPlanMetricsSet::new(), input_order_mode: self.input_order_mode.clone(), cache: self.cache.clone(), mode: self.mode, - group_by: self.group_by.clone(), - filter_expr: self.filter_expr.clone(), - limit: self.limit, + group_by: Arc::clone(&self.group_by), + filter_expr: Arc::clone(&self.filter_expr), + limit_options: self.limit_options, input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), @@ -567,18 +688,18 @@ impl AggregateExec { } /// Clone this exec, overriding only the limit hint. 
- pub fn with_new_limit(&self, limit: Option) -> Self { + pub fn with_new_limit_options(&self, limit_options: Option) -> Self { Self { - limit, + limit_options, // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), metrics: ExecutionPlanMetricsSet::new(), input_order_mode: self.input_order_mode.clone(), cache: self.cache.clone(), mode: self.mode, - group_by: self.group_by.clone(), - aggr_expr: self.aggr_expr.clone(), - filter_expr: self.filter_expr.clone(), + group_by: Arc::clone(&self.group_by), + aggr_expr: Arc::clone(&self.aggr_expr), + filter_expr: Arc::clone(&self.filter_expr), input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), @@ -593,12 +714,13 @@ impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); @@ -623,13 +745,16 @@ impl AggregateExec { /// the schema in such cases. 
fn try_new_with_schema( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, mut aggr_expr: Vec>, - filter_expr: Vec>>, + filter_expr: impl Into>]>>, input: Arc, input_schema: SchemaRef, schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); + let filter_expr = filter_expr.into(); + // Make sure arguments are consistent in size assert_eq_or_internal_err!( aggr_expr.len(), @@ -696,20 +821,20 @@ impl AggregateExec { &group_expr_mapping, &mode, &input_order_mode, - aggr_expr.as_slice(), + aggr_expr.as_ref(), )?; let mut exec = AggregateExec { mode, group_by, - aggr_expr, + aggr_expr: aggr_expr.into(), filter_expr, input, schema, input_schema, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, - limit: None, + limit_options: None, input_order_mode, cache, dynamic_filter: None, @@ -725,11 +850,17 @@ impl AggregateExec { &self.mode } - /// Set the `limit` of this AggExec - pub fn with_limit(mut self, limit: Option) -> Self { - self.limit = limit; + /// Set the limit options for this AggExec + pub fn with_limit_options(mut self, limit_options: Option) -> Self { + self.limit_options = limit_options; self } + + /// Get the limit options (if set) + pub fn limit_options(&self) -> Option { + self.limit_options + } + /// Grouping expressions pub fn group_expr(&self) -> &PhysicalGroupBy { &self.group_by @@ -760,11 +891,6 @@ impl AggregateExec { Arc::clone(&self.input_schema) } - /// number of rows soft limit of the AggregateExec - pub fn limit(&self) -> Option { - self.limit - } - fn execute_typed( &self, partition: usize, @@ -777,11 +903,11 @@ impl AggregateExec { } // grouping by an expression that has a sort/limit upstream - if let Some(limit) = self.limit + if let Some(config) = self.limit_options && !self.is_unordered_unfiltered_group_by_distinct() { return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, + GroupedTopKAggregateStream::new(self, context, partition, 
config.limit)?, )); } @@ -802,6 +928,13 @@ impl AggregateExec { /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule /// on an AggregateExec. pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + if self + .limit_options() + .and_then(|config| config.descending) + .is_some() + { + return false; + } // ensure there is a group by if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; @@ -873,14 +1006,15 @@ impl AggregateExec { // Get output partitioning: let input_partitioning = input.output_partitioning().clone(); - let output_partitioning = if mode.is_first_stage() { - // First stage aggregation will not change the output partitioning, - // but needs to respect aliases (e.g. mapping in the GROUP BY - // expression). - let input_eq_properties = input.equivalence_properties(); - input_partitioning.project(group_expr_mapping, input_eq_properties) - } else { - input_partitioning.clone() + let output_partitioning = match mode.input_mode() { + AggregateInputMode::Raw => { + // First stage aggregation will not change the output partitioning, + // but needs to respect aliases (e.g. mapping in the GROUP BY + // expression). + let input_eq_properties = input.equivalence_properties(); + input_partitioning.project(group_expr_mapping, input_eq_properties) + } + AggregateInputMode::Partial => input_partitioning.clone(), }; // TODO: Emission type and boundedness information can be enhanced here @@ -1013,7 +1147,7 @@ impl AggregateExec { } else if fun_name.eq_ignore_ascii_case("max") { DynamicFilterAggregateType::Max } else { - continue; + return; }; // 2. 
arg should be only 1 column reference @@ -1119,8 +1253,8 @@ impl DisplayAs for AggregateExec { .map(|agg| agg.name().to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; - if let Some(limit) = self.limit { - write!(f, ", lim=[{limit}]")?; + if let Some(config) = self.limit_options { + write!(f, ", lim=[{}]", config.limit)?; } if self.input_order_mode != InputOrderMode::Linear { @@ -1179,6 +1313,9 @@ impl DisplayAs for AggregateExec { if !a.is_empty() { writeln!(f, "aggr={}", a.join(", "))?; } + if let Some(config) = self.limit_options { + writeln!(f, "limit={}", config.limit)?; + } } } Ok(()) @@ -1201,7 +1338,7 @@ impl ExecutionPlan for AggregateExec { fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::PartialReduce => { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { @@ -1240,14 +1377,14 @@ impl ExecutionPlan for AggregateExec { ) -> Result> { let mut me = AggregateExec::try_new_with_schema( self.mode, - self.group_by.clone(), - self.aggr_expr.clone(), - self.filter_expr.clone(), + Arc::clone(&self.group_by), + self.aggr_expr.to_vec(), + Arc::clone(&self.filter_expr), Arc::clone(&children[0]), Arc::clone(&self.input_schema), Arc::clone(&self.schema), )?; - me.limit = self.limit; + me.limit_options = self.limit_options; me.dynamic_filter = self.dynamic_filter.clone(); Ok(Arc::new(me)) @@ -1266,10 +1403,6 @@ impl ExecutionPlan for AggregateExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let child_statistics = self.input().partition_statistics(partition)?; self.statistics_inner(&child_statistics) @@ -1430,20 +1563,17 @@ fn create_schema( let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); 
fields.extend(group_by.output_fields(input_schema)?); - match mode { - AggregateMode::Partial => { - // in partial mode, the fields of the accumulator's state + match mode.output_mode() { + AggregateOutputMode::Final => { + // in final mode, the field with the final result of the accumulator for expr in aggr_expr { - fields.extend(expr.state_fields()?.iter().cloned()); + fields.push(expr.field()) } } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // in final mode, the field with the final result of the accumulator + AggregateOutputMode::Partial => { + // in partial mode, the fields of the accumulator's state for expr in aggr_expr { - fields.push(expr.field()) + fields.extend(expr.state_fields()?.iter().cloned()); } } } @@ -1483,7 +1613,7 @@ fn get_aggregate_expr_req( // If the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. Ordering requirement applies // only to the aggregation input data. - if !agg_mode.is_first_stage() { + if agg_mode.input_mode() == AggregateInputMode::Partial { return None; } @@ -1649,10 +1779,8 @@ pub fn aggregate_expressions( mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { - match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => Ok(aggr_expr + match mode.input_mode() { + AggregateInputMode::Raw => Ok(aggr_expr .iter() .map(|agg| { let mut result = agg.expressions(); @@ -1663,8 +1791,8 @@ pub fn aggregate_expressions( result }) .collect()), - // In this mode, we build the merge expressions of the aggregation. - AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateInputMode::Partial => { + // In merge mode, we build the merge expressions of the aggregation. 
let mut col_idx_base = col_idx_base; aggr_expr .iter() @@ -1712,8 +1840,15 @@ pub fn finalize_aggregation( accumulators: &mut [AccumulatorItem], mode: &AggregateMode, ) -> Result> { - match mode { - AggregateMode::Partial => { + match mode.output_mode() { + AggregateOutputMode::Final => { + // Merge the state to the final value + accumulators + .iter_mut() + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .collect() + } + AggregateOutputMode::Partial => { // Build the vector of states accumulators .iter_mut() @@ -1727,16 +1862,6 @@ pub fn finalize_aggregation( .flatten_ok() .collect() } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // Merge the state to the final value - accumulators - .iter_mut() - .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) - .collect() - } } } @@ -2358,10 +2483,6 @@ mod tests { Ok(Box::pin(stream)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(self.schema().as_ref())); @@ -3698,4 +3819,135 @@ mod tests { } Ok(()) } + + /// Tests that PartialReduce mode: + /// 1. Accepts state as input (like Final) + /// 2. Produces state as output (like Partial) + /// 3. 
Can be followed by a Final stage to get the correct result + /// + /// This simulates a tree-reduce pattern: + /// Partial -> PartialReduce -> Final + #[tokio::test] + async fn test_partial_reduce_mode() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Produce two partitions of input data + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])), + ], + )?; + + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("SUM(b)") + .build()?, + )]; + + // Step 1: Partial aggregation on partition 1 + let input1 = + TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?; + let partial1 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], + input1, + Arc::clone(&schema), + )?); + + // Step 2: Partial aggregation on partition 2 + let input2 = + TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?; + let partial2 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], + input2, + Arc::clone(&schema), + )?); + + // Collect partial results + let task_ctx = Arc::new(TaskContext::default()); + let partial_result1 = + crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?; + let partial_result2 = + crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?; + + // The partial results have state schema (group cols + 
accumulator state) + let partial_schema = partial1.schema(); + + // Step 3: PartialReduce — combine partial results, still producing state + let combined_input = TestMemoryExec::try_new_exec( + &[partial_result1, partial_result2], + Arc::clone(&partial_schema), + None, + )?; + // Coalesce into a single partition for the PartialReduce + let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input)); + + let partial_reduce = Arc::new(AggregateExec::try_new( + AggregateMode::PartialReduce, + groups.clone(), + aggregates.clone(), + vec![None], + coalesced, + Arc::clone(&partial_schema), + )?); + + // Verify PartialReduce output schema matches Partial output schema + // (both produce state, not final values) + assert_eq!(partial_reduce.schema(), partial_schema); + + // Collect PartialReduce results + let reduce_result = + crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx)) + .await?; + + // Step 4: Final aggregation on the PartialReduce output + let final_input = TestMemoryExec::try_new_exec( + &[reduce_result], + Arc::clone(&partial_schema), + None, + )?; + let final_agg = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + groups.clone(), + aggregates.clone(), + vec![None], + final_input, + Arc::clone(&partial_schema), + )?); + + let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?; + + // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90 + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+--------+ + | a | SUM(b) | + +---+--------+ + | 1 | 50.0 | + | 2 | 70.0 | + | 3 | 90.0 | + +---+--------+ + "); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index a55d70ca6fb27..a7dd7c9a66cb1 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -18,8 +18,9 @@ //! 
Aggregate without grouping columns use crate::aggregates::{ - AccumulatorItem, AggrDynFilter, AggregateMode, DynamicFilterAggregateType, - aggregate_expressions, create_accumulators, finalize_aggregation, + AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode, + DynamicFilterAggregateType, aggregate_expressions, create_accumulators, + finalize_aggregation, }; use crate::metrics::{BaselineMetrics, RecordOutput}; use crate::{RecordBatchStream, SendableRecordBatchStream}; @@ -61,7 +62,7 @@ struct AggregateStreamInner { mode: AggregateMode, input: SendableRecordBatchStream, aggregate_expressions: Vec>>, - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, // ==== Runtime States/Buffers ==== accumulators: Vec, @@ -160,6 +161,8 @@ impl AggregateStreamInner { return Ok(()); }; + let mut bounds_changed = false; + for acc_info in &filter_state.supported_accumulators_info { let acc = self.accumulators @@ -175,20 +178,27 @@ impl AggregateStreamInner { let current_bound = acc.evaluate()?; { let mut bound = acc_info.shared_bound.lock(); - match acc_info.aggr_type { + let new_bound = match acc_info.aggr_type { DynamicFilterAggregateType::Max => { - *bound = scalar_max(&bound, ¤t_bound)?; + scalar_max(&bound, ¤t_bound)? } DynamicFilterAggregateType::Min => { - *bound = scalar_min(&bound, ¤t_bound)?; + scalar_min(&bound, ¤t_bound)? } + }; + if new_bound != *bound { + *bound = new_bound; + bounds_changed = true; } } } - // Step 2: Sync the dynamic filter physical expression with reader - let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; - filter_state.filter.update(predicate)?; + // Step 2: Sync the dynamic filter physical expression with reader, + // but only if any bound actually changed. 
+ if bounds_changed { + let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; + filter_state.filter.update(predicate)?; + } Ok(()) } @@ -276,19 +286,15 @@ impl AggregateStream { partition: usize, ) -> Result { let agg_schema = Arc::clone(&agg.schema); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); let input = agg.input.execute(partition, Arc::clone(context))?; let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; let accumulators = create_accumulators(&agg.aggr_expr)?; @@ -455,13 +461,9 @@ fn aggregate_batch( // 1.4 let size_pre = accum.size(); - let res = match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => accum.update_batch(&values), - AggregateMode::Final | AggregateMode::FinalPartitioned => { - accum.merge_batch(&values) - } + let res = match mode.input_mode() { + AggregateInputMode::Raw => accum.update_batch(&values), + AggregateInputMode::Partial => accum.merge_batch(&values), }; let size_post = accum.size(); allocated += size_post.saturating_sub(size_pre); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1ae7202711112..de857370ce285 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -26,8 +26,8 @@ use super::order::GroupOrdering; use 
crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ - AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, evaluate_many, - evaluate_optional, + AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy, + create_schema, evaluate_group_by, evaluate_many, evaluate_optional, }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; @@ -377,10 +377,10 @@ pub(crate) struct GroupedHashAggregateStream { /// /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, /// the filter expression is `x > 100`. - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, /// GROUP BY expressions - group_by: PhysicalGroupBy, + group_by: Arc, /// max rows in output RecordBatches batch_size: usize, @@ -465,8 +465,8 @@ impl GroupedHashAggregateStream { ) -> Result { debug!("Creating GroupedHashAggregateStream"); let agg_schema = Arc::clone(&agg.schema); - let agg_group_by = agg.group_by.clone(); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_group_by = Arc::clone(&agg.group_by); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let batch_size = context.session_config().batch_size(); let input = agg.input.execute(partition, Arc::clone(context))?; @@ -475,7 +475,7 @@ impl GroupedHashAggregateStream { let timer = baseline_metrics.elapsed_compute().timer(); - let aggregate_exprs = agg.aggr_expr.clone(); + let aggregate_exprs = Arc::clone(&agg.aggr_expr); // arguments for each aggregate, one vec of expressions per // aggregate @@ -491,13 +491,9 @@ impl GroupedHashAggregateStream { agg_group_by.num_group_exprs(), )?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let 
filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; // Instantiate the accumulators @@ -679,7 +675,7 @@ impl GroupedHashAggregateStream { group_ordering, input_done: false, spill_state, - group_values_soft_limit: agg.limit, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), skip_aggregation_probe, reduction_factor, }) @@ -982,29 +978,24 @@ impl GroupedHashAggregateStream { // Call the appropriate method on each aggregator with // the entire input row and the relevant group indexes - match self.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned - if !self.spill_state.is_stream_merging => - { - acc.update_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; - } - _ => { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" - ); - - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; - } + if self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + // if aggregation is over intermediate states, + // use merge + acc.merge_batch(values, group_indices, None, total_num_groups)?; } self.group_by_metrics .aggregation_time @@ -1092,17 +1083,12 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { - match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), - _ if spilling => { - // If spilling, output partial state because the 
spilled data will be - // merged and re-evaluated later. - output.extend(acc.state(emit_to)?) - } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), + if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { + output.push(acc.evaluate(emit_to)?) + } else { + // Output partial state: either because we're in a non-final mode, + // or because we're spilling and will merge/re-evaluate later. + output.extend(acc.state(emit_to)?) } } drop(timer); diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index a43b5cff12989..4aa566ccfcd0a 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -19,6 +19,7 @@ use crate::aggregates::group_values::GroupByMetrics; use crate::aggregates::topk::priority_map::PriorityMap; +#[cfg(debug_assertions)] use crate::aggregates::topk_types_supported; use crate::aggregates::{ AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, @@ -33,6 +34,7 @@ use datafusion_common::Result; use datafusion_common::internal_datafusion_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::metrics::RecordOutput; use futures::stream::{Stream, StreamExt}; use log::{Level, trace}; use std::pin::Pin; @@ -48,7 +50,7 @@ pub struct GroupedTopKAggregateStream { baseline_metrics: BaselineMetrics, group_by_metrics: GroupByMetrics, aggregate_arguments: Vec>>, - group_by: PhysicalGroupBy, + group_by: Arc, priority_map: PriorityMap, } @@ -60,19 +62,33 @@ impl GroupedTopKAggregateStream { limit: usize, ) -> Result { let agg_schema = Arc::clone(&aggr.schema); - let group_by = aggr.group_by.clone(); + let group_by = Arc::clone(&aggr.group_by); let input = aggr.input.execute(partition, Arc::clone(context))?; let 
baseline_metrics = BaselineMetrics::new(&aggr.metrics, partition); let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition); let aggregate_arguments = aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; - let (val_field, desc) = aggr - .get_minmax_desc() - .ok_or_else(|| internal_datafusion_err!("Min/max required"))?; let (expr, _) = &aggr.group_expr().expr()[0]; let kt = expr.data_type(&aggr.input().schema())?; - let vt = val_field.data_type().clone(); + + // Check if this is a MIN/MAX aggregate or a DISTINCT-like operation + let (vt, desc) = if let Some((val_field, desc)) = aggr.get_minmax_desc() { + // MIN/MAX case: use the aggregate output type + (val_field.data_type().clone(), desc) + } else { + // DISTINCT case: use the group key type and get ordering from limit_order_descending + // The ordering direction is set by the optimizer when it pushes down the limit + let desc = aggr + .limit_options() + .and_then(|config| config.descending) + .ok_or_else(|| { + internal_datafusion_err!( + "Ordering direction required for DISTINCT with limit" + ) + })?; + (kt.clone(), desc) + }; // Type validation is performed by the optimizer and can_use_topk() check. // This debug assertion documents the contract without runtime overhead in release builds. @@ -168,18 +184,21 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); let group_by_values = Arc::clone(&group_by_values[0][0]); - let input_values = { - let _timer = (!self.aggregate_arguments.is_empty()).then(|| { - self.group_by_metrics.aggregate_arguments_time.timer() - }); - evaluate_many( + let input_values = if self.aggregate_arguments.is_empty() { + // DISTINCT case: use group key as both key and value + Arc::clone(&group_by_values) + } else { + // MIN/MAX case: evaluate aggregate expressions + let _timer = + self.group_by_metrics.aggregate_arguments_time.timer(); + let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), - )? 
+ )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + Arc::clone(&input_values[0][0]) }; - assert_eq!(input_values.len(), 1, "Exactly 1 input required"); - assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(&group_by_values, &input_values)?; @@ -192,9 +211,15 @@ impl Stream for GroupedTopKAggregateStream { } let batch = { let _timer = emitting_time.timer(); - let cols = self.priority_map.emit()?; + let mut cols = self.priority_map.emit()?; + // For DISTINCT case (no aggregate expressions), only use the group key column + // since the schema only has one field and key/value are the same + if self.aggregate_arguments.is_empty() { + cols.truncate(1); + } RecordBatch::try_new(Arc::clone(&self.schema), cols)? }; + let batch = batch.record_output(&self.baseline_metrics); trace!( "partition {} emit batch with {} rows", self.partition, diff --git a/datafusion/physical-plan/src/buffer.rs b/datafusion/physical-plan/src/buffer.rs new file mode 100644 index 0000000000000..3b80f9924e311 --- /dev/null +++ b/datafusion/physical-plan/src/buffer.rs @@ -0,0 +1,629 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`BufferExec`] decouples production and consumption on messages by buffering the input in the +//! background up to a certain capacity. + +use crate::execution_plan::{CardinalityEffect, SchedulingType}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::projection::ProjectionExec; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult, +}; +use arrow::array::RecordBatch; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, Statistics, internal_err, plan_err}; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use futures::{Stream, StreamExt, TryStreamExt}; +use pin_project_lite::pin_project; +use std::any::Any; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// WARNING: EXPERIMENTAL +/// +/// Decouples production and consumption of record batches with an internal queue per partition, +/// eagerly filling up the capacity of the queues even before any message is requested. 
+/// +/// ```text +/// ┌───────────────────────────┐ +/// │ BufferExec │ +/// │ │ +/// │┌────── Partition 0 ──────┐│ +/// ││ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll────────▶│ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │┌────── Partition 1 ──────┐│ +/// ││ ┌────┐ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll─▶│ │ │ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │ │ +/// │ ... │ +/// │ │ +/// │┌────── Partition N ──────┐│ +/// ││ ┌────┐││ ┌────┐ +/// ──background poll───────────────▶│ ├┼┼───────▶ │ +/// ││ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// └───────────────────────────┘ +/// ``` +/// +/// The capacity is provided in bytes, and for each buffered record batch it will take into account +/// the size reported by [RecordBatch::get_array_memory_size]. +/// +/// If a single record batch exceeds the maximum capacity set in the `capacity` argument, it's still +/// allowed to pass in order to not deadlock the buffer. +/// +/// This is useful for operators that conditionally start polling one of their children only after +/// other child has finished, allowing to perform some early work and accumulating batches in +/// memory so that they can be served immediately when requested. +#[derive(Debug, Clone)] +pub struct BufferExec { + input: Arc, + properties: PlanProperties, + capacity: usize, + metrics: ExecutionPlanMetricsSet, +} + +impl BufferExec { + /// Builds a new [BufferExec] with the provided capacity in bytes. + pub fn new(input: Arc, capacity: usize) -> Self { + let properties = input + .properties() + .clone() + .with_scheduling_type(SchedulingType::Cooperative); + + Self { + input, + properties, + capacity, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Returns the input [ExecutionPlan] of this [BufferExec]. + pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns the per-partition capacity in bytes for this [BufferExec]. 
+ pub fn capacity(&self) -> usize { + self.capacity + } +} + +impl DisplayAs for BufferExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BufferExec: capacity={}", self.capacity) + } + DisplayFormatType::TreeRender => { + writeln!(f, "target_batch_size={}", self.capacity) + } + } + } +} + +impl ExecutionPlan for BufferExec { + fn name(&self) -> &str { + "BufferExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + if children.len() != 1 { + return plan_err!("BufferExec can only have one child"); + } + Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]")) + .register(context.memory_pool()); + let in_stream = self.input.execute(partition, context)?; + + // Set up the metrics for the stream. + let curr_mem_in = Arc::new(AtomicUsize::new(0)); + let curr_mem_out = Arc::clone(&curr_mem_in); + let mut max_mem_in = 0; + let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition); + + let curr_queued_in = Arc::new(AtomicUsize::new(0)); + let curr_queued_out = Arc::clone(&curr_queued_in); + let mut max_queued_in = 0; + let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition); + + // Capture metrics when an element is queued on the stream. 
+ let in_stream = in_stream.inspect_ok(move |v| { + let size = v.get_array_memory_size(); + let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size; + if curr_size > max_mem_in { + max_mem_in = curr_size; + max_mem.set(max_mem_in); + } + + let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1; + if curr_queued > max_queued_in { + max_queued_in = curr_queued; + max_queued.set(max_queued_in); + } + }); + // Buffer the input. + let out_stream = + MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation); + // Update in the metrics that when an element gets out, some memory gets freed. + let out_stream = out_stream.inspect_ok(move |v| { + curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed); + curr_queued_out.fetch_sub(1, Ordering::Relaxed); + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + out_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + self.input.supports_limit_pushdown() + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } + + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + match self.input.try_swapping_with_projection(projection)? 
{ + Some(new_input) => Ok(Some( + Arc::new(self.clone()).with_new_children(vec![new_input])?, + )), + None => Ok(None), + } + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // CoalesceBatchesExec is transparent for sort ordering - it preserves order + // Delegate to the child and wrap with a new CoalesceBatchesExec + self.input.try_pushdown_sort(order)?.try_map(|new_input| { + Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc) + }) + } +} + +/// Represents anything that occupies a capacity in a [MemoryBufferedStream]. +pub trait SizedMessage { + fn size(&self) -> usize; +} + +impl SizedMessage for RecordBatch { + fn size(&self) -> usize { + self.get_array_memory_size() + } +} + +pin_project! { +/// Decouples production and consumption of messages in a stream with an internal queue, eagerly +/// filling it up to the specified maximum capacity even before any message is requested. +/// +/// Allows each message to have a different size, which is taken into account for determining if +/// the queue is full or not. +pub struct MemoryBufferedStream { + task: SpawnedTask<()>, + batch_rx: UnboundedReceiver>, + memory_reservation: Arc, +}} + +impl MemoryBufferedStream { + /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler. + /// + /// This immediately spawns a Tokio task that will start consumption of the input stream. 
+ pub fn new( + mut input: impl Stream> + Unpin + Send + 'static, + capacity: usize, + memory_reservation: MemoryReservation, + ) -> Self { + let semaphore = Arc::new(Semaphore::new(capacity)); + let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel(); + + let memory_reservation = Arc::new(memory_reservation); + let memory_reservation_clone = Arc::clone(&memory_reservation); + let task = SpawnedTask::spawn(async move { + loop { + // Select on both the input stream and the channel being closed. + // By down this, we abort polling the input as soon as the consumer channel is + // closed. Otherwise, we would need to wait for a full new message to be available + // in order to consider aborting the stream + let item_or_err = tokio::select! { + biased; + _ = batch_tx.closed() => break, + item_or_err = input.next() => { + let Some(item_or_err) = item_or_err else { + break; // stream finished + }; + item_or_err + } + }; + + let item = match item_or_err { + Ok(batch) => batch, + Err(err) => { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + }; + + let size = item.size(); + if let Err(err) = memory_reservation.try_grow(size) { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + + // We need to cap the minimum between amount of permits and the actual size of the + // message. If at any point we try to acquire more permits than the capacity of the + // semaphore, the stream will deadlock. + let capped_size = size.min(capacity) as u32; + + let semaphore = Arc::clone(&semaphore); + let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else { + let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. 
This is a bug in DataFusion, please report it!")); + break; + }; + + if batch_tx.send(Ok((item, permit))).is_err() { + break; // stream was closed + }; + } + }); + + Self { + task, + batch_rx, + memory_reservation: memory_reservation_clone, + } + } + + /// Returns the number of queued messages. + pub fn messages_queued(&self) -> usize { + self.batch_rx.len() + } +} + +impl Stream for MemoryBufferedStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_project = self.project(); + match self_project.batch_rx.poll_recv(cx) { + Poll::Ready(Some(Ok((item, _semaphore_permit)))) => { + self_project.memory_reservation.shrink(item.size()); + Poll::Ready(Some(Ok(item))) + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.batch_rx.is_closed() { + let len = self.batch_rx.len(); + (len, Some(len)) + } else { + (self.batch_rx.len(), None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::{DataFusionError, assert_contains}; + use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryPool, UnboundedMemoryPool, + }; + use std::error::Error; + use std::fmt::Debug; + use std::sync::Arc; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn buffers_only_some_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let buffered = MemoryBufferedStream::new(input, 4, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 2); + Ok(()) + } + + #[tokio::test] + async fn yields_all_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + 
wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn yields_first_msg_even_if_big() -> Result<(), Box> { + let input = futures::stream::iter([25, 1, 2, 3]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn memory_pool_kills_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + let msg = pull_err_msg(&mut buffered).await?; + + assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B"); + Ok(()) + } + + #[tokio::test] + async fn memory_pool_does_not_kill_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 3, res); + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box> { + let input = futures::stream::iter([3, 3, 3, 3]).map(Ok); + let (_, res) = 
memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 2, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn errors_get_propagated() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(|v| { + if v == 3 { + return internal_err!("Error on 3"); + } + Ok(v) + }); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_err_msg(&mut buffered).await?; + + Ok(()) + } + + #[tokio::test] + async fn memory_gets_released_if_stream_drops() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (pool, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + assert_eq!(pool.reserved(), 10); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 3); + assert_eq!(pool.reserved(), 9); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 2); + assert_eq!(pool.reserved(), 7); + + drop(buffered); + assert_eq!(pool.reserved(), 0); + Ok(()) + } + + fn memory_pool_and_reservation() -> (Arc, MemoryReservation) { + let pool = Arc::new(UnboundedMemoryPool::default()) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) 
+ } + + fn bounded_memory_pool_and_reservation( + size: usize, + ) -> (Arc, MemoryReservation) { + let pool = Arc::new(GreedyMemoryPool::new(size)) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) + } + + async fn wait_for_buffering() { + // We do not have control over the spawned task, so the best we can do is to yield some + // cycles to the tokio runtime and let the task make progress on its own. + tokio::time::sleep(Duration::from_millis(1)).await; + } + + async fn pull_ok_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn pull_err_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .map(|v| match v { + Ok(v) => internal_err!( + "Stream should not have failed, but succeeded with {v:?}" + ), + Err(err) => Ok(err), + }) + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn finished( + buffered: &mut MemoryBufferedStream, + ) -> Result<(), Box> { + match timeout(Duration::from_millis(1), buffered.next()) + .await? 
+ .is_none() + { + true => Ok(()), + false => internal_err!("Stream should have finished")?, + } + } + + impl SizedMessage for usize { + fn size(&self) -> usize { + *self + } + } +} diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index b3947170d9e41..ea1a87d091481 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -134,6 +134,10 @@ impl LimitedBatchCoalescer { Ok(()) } + pub(crate) fn is_finished(&self) -> bool { + self.finished + } + /// Return the next completed batch, if any pub fn next_completed_batch(&mut self) -> Option { self.inner.next_completed_batch() diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index dfcd3cb0bcae7..1356eca78329d 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -206,10 +206,6 @@ impl ExecutionPlan for CoalesceBatchesExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input .partition_statistics(partition)? diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index d83f90eb3d8c1..d1fc58837b0fa 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -224,10 +224,6 @@ impl ExecutionPlan for CoalescePartitionsExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, _partition: Option) -> Result { self.input .partition_statistics(None)? 
@@ -278,6 +274,19 @@ impl ExecutionPlan for CoalescePartitionsExec { })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn gather_filters_for_pushdown( &self, _phase: FilterPushdownPhase, diff --git a/datafusion/physical-plan/src/column_rewriter.rs b/datafusion/physical-plan/src/column_rewriter.rs new file mode 100644 index 0000000000000..7cd8656304554 --- /dev/null +++ b/datafusion/physical-plan/src/column_rewriter.rs @@ -0,0 +1,383 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{ + DataFusionError, HashMap, + tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter}, +}; +use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; + +/// Rewrite column references in a physical expr according to a mapping. +/// +/// This rewriter traverses the expression tree and replaces [`Column`] nodes +/// with the corresponding expression found in the `column_map`. +/// +/// If a column is found in the map, it is replaced by the mapped expression. 
+/// If a column is NOT found in the map, a `DataFusionError::Internal` is +/// returned. +pub struct PhysicalColumnRewriter<'a> { + /// Mapping from original column to new column. + pub column_map: &'a HashMap>, +} + +impl<'a> PhysicalColumnRewriter<'a> { + /// Create a new PhysicalColumnRewriter with the given column mapping. + pub fn new(column_map: &'a HashMap>) -> Self { + Self { column_map } + } +} + +impl<'a> TreeNodeRewriter for PhysicalColumnRewriter<'a> { + type Node = Arc; + + fn f_down( + &mut self, + node: Self::Node, + ) -> datafusion_common::Result> { + if let Some(column) = node.as_any().downcast_ref::() { + if let Some(new_column) = self.column_map.get(column) { + // jump to prevent rewriting the new sub-expression again + return Ok(Transformed::new( + Arc::clone(new_column), + true, + TreeNodeRecursion::Jump, + )); + } else { + // Column not found in mapping + return Err(DataFusionError::Internal(format!( + "Column {column:?} not found in column mapping {:?}", + self.column_map + ))); + } + } + Ok(Transformed::no(node)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{DataFusionError, Result, tree_node::TreeNode}; + use datafusion_physical_expr::{ + PhysicalExpr, + expressions::{Column, binary, col, lit}, + }; + use std::sync::Arc; + + /// Helper function to create a test schema + fn create_test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("new_col", DataType::Int32, true), + Field::new("inner_col", DataType::Int32, true), + Field::new("another_col", DataType::Int32, true), + ])) + } + + /// Helper function to create a complex nested expression with multiple columns + /// Create: (col_a + col_b) * (col_c - col_d) + col_e + fn 
create_complex_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let add_expr = + binary(col_a, datafusion_expr::Operator::Plus, col_b, schema).unwrap(); + let sub_expr = + binary(col_c, datafusion_expr::Operator::Minus, col_d, schema).unwrap(); + let mul_expr = binary( + add_expr, + datafusion_expr::Operator::Multiply, + sub_expr, + schema, + ) + .unwrap(); + binary(mul_expr, datafusion_expr::Operator::Plus, col_e, schema).unwrap() + } + + /// Helper function to create a deeply nested expression + /// Create: col_a + (col_b + (col_c + (col_d + col_e))) + fn create_deeply_nested_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let inner1 = + binary(col_d, datafusion_expr::Operator::Plus, col_e, schema).unwrap(); + let inner2 = + binary(col_c, datafusion_expr::Operator::Plus, inner1, schema).unwrap(); + let inner3 = + binary(col_b, datafusion_expr::Operator::Plus, inner2, schema).unwrap(); + binary(col_a, datafusion_expr::Operator::Plus, inner3, schema).unwrap() + } + + #[test] + fn test_simple_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + + // Test that Jump prevents re-processing of replaced columns + let mut column_map = HashMap::new(); + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + lit("replaced_b"), + ); + column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + col("c", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + 
column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify the transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "(42 + replaced_b) * (c@2 - d@3) + e@4" + ); + + Ok(()) + } + + #[test] + fn test_nested_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior with deeply nested expressions + let mut column_map = HashMap::new(); + // Replace col_c with a complex expression containing new columns + let replacement_expr = binary( + lit(100i32), + datafusion_expr::Operator::Plus, + col("new_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + replacement_expr, + ); + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_deeply_nested_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "a@0 + b@1 + 100 + new_col@5 + d@3 + e@4" + ); + + Ok(()) + } + + #[test] + fn test_circular_reference_prevention() -> Result<()> { + let schema = create_test_schema(); + // Test that Jump prevents infinite recursion with circular references + let mut column_map = HashMap::new(); + + // Create a circular reference: col_a -> 
col_b -> col_a (but Jump should prevent the second visit) + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Start with an expression containing col_a + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!(format!("{}", result.data), "b@1 + a@0"); + + Ok(()) + } + + #[test] + fn test_multiple_replacements_in_same_expression() -> Result<()> { + let schema = create_test_schema(); + // Test multiple column replacements in the same complex expression + let mut column_map = HashMap::new(); + + // Replace multiple columns with literals + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(10i32)); + column_map.insert(Column::new_with_schema("c", &schema).unwrap(), lit(20i32)); + column_map.insert(Column::new_with_schema("e", &schema).unwrap(), lit(30i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); // (col_a + col_b) * (col_c - col_d) + col_e + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + assert_eq!(format!("{}", result.data), "(10 + b@1) * (20 - d@3) + 30"); + + Ok(()) + } + + #[test] + fn test_jump_with_complex_replacement_expression() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior when replacing with very complex expressions + 
let mut column_map = HashMap::new(); + + // Replace col_a with a complex nested expression + let inner_expr = binary( + lit(5i32), + datafusion_expr::Operator::Multiply, + col("a", &schema).unwrap(), + &schema, + ) + .unwrap(); + let middle_expr = binary( + inner_expr, + datafusion_expr::Operator::Plus, + lit(3i32), + &schema, + ) + .unwrap(); + let complex_replacement = binary( + middle_expr, + datafusion_expr::Operator::Minus, + col("another_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + complex_replacement, + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + assert_eq!( + format!("{}", result.data), + "5 * a@0 + 3 - another_col@7 + b@1" + ); + + // Verify transformation occurred + assert!(result.transformed); + + Ok(()) + } + + #[test] + fn test_unmapped_columns_detection() -> Result<()> { + let schema = create_test_schema(); + let mut column_map = HashMap::new(); + + // Only map col_a, leave col_b unmapped + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let err = expr.rewrite(&mut rewriter).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(_))); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 32dc60b56ad48..590f6f09e8b9e 100644 --- 
a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -181,7 +181,7 @@ pub fn compute_record_batch_statistics( /// Checks if the given projection is valid for the given schema. pub fn can_project( schema: &arrow::datatypes::SchemaRef, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result<()> { match projection { Some(columns) => { diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs index a1fad86777408..ce54a451ac4d1 100644 --- a/datafusion/physical-plan/src/coop.rs +++ b/datafusion/physical-plan/src/coop.rs @@ -22,10 +22,15 @@ //! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work //! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a //! large dataset. +//! //! If a `Stream` runs for a long period of time without yielding back to the Tokio executor, //! it can starve other tasks waiting on that executor to execute them. //! Additionally, this prevents the query execution from being cancelled. //! +//! For more background, please also see the [Using Rust async for Query Execution and Cancelling Long-Running Queries blog] +//! +//! [Using Rust async for Query Execution and Cancelling Long-Running Queries blog]: https://datafusion.apache.org/blog/2025/06/30/cancellation +//! //! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield //! points using the utilities in this module. For most operators this is **not** necessary. The //! 
`Stream`s of the built-in DataFusion operators that generate (rather than manipulate) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 52c37a106b39e..19698cd4ea78c 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -1176,10 +1176,6 @@ mod tests { todo!() } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(self.schema().as_ref())); diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index fcfbcfa3e8277..64808bbc25167 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::memory::MemoryStream; -use crate::{DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics, common}; +use crate::{DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; use crate::{ DisplayFormatType, ExecutionPlan, Partitioning, execution_plan::{Boundedness, EmissionType}, @@ -29,7 +29,8 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, assert_or_internal_err}; +use datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, Result, ScalarValue, assert_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -155,10 +156,6 @@ impl ExecutionPlan for EmptyExec { )?)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition) = partition { assert_or_internal_err!( @@ -169,20 +166,31 @@ impl ExecutionPlan for EmptyExec { ); } - let batch = self - .data() - .expect("Create empty RecordBatch should not fail"); - 
Ok(common::compute_record_batch_statistics( - &[batch], - &self.schema, - None, - )) + // Build explicit stats: exact zero rows and bytes, with explicit known column stats + let mut stats = Statistics::default() + .with_num_rows(Precision::Exact(0)) + .with_total_byte_size(Precision::Exact(0)); + + // Add explicit column stats for each field in schema + for _ in self.schema.fields() { + stats = stats.add_column_statistics(ColumnStatistics { + null_count: Precision::Exact(0), + distinct_count: Precision::Exact(0), + min_value: Precision::::Absent, + max_value: Precision::::Absent, + sum_value: Precision::::Absent, + byte_size: Precision::Exact(0), + }); + } + + Ok(stats) } } #[cfg(test)] mod tests { use super::*; + use crate::common; use crate::test; use crate::with_new_children_if_necessary; diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 06da0b8933c18..2ce1e79601c52 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -26,6 +26,7 @@ use crate::sort_pushdown::SortOrderPushdownResult; pub use crate::stream::EmptyRecordBatchStream; pub use datafusion_common::hash_utils; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; pub use datafusion_common::utils::project_schema; pub use datafusion_common::{ColumnStatistics, Statistics, internal_err}; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -471,17 +472,6 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { None } - /// Returns statistics for this `ExecutionPlan` node. If statistics are not - /// available, should return [`Statistics::new_unknown`] (the default), not - /// an error. 
- /// - /// For TableScan executors, which supports filter pushdown, special attention - /// needs to be paid to whether the stats returned by this method are exact or not - #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - /// Returns statistics for a specific partition of this `ExecutionPlan` node. /// If statistics are not available, should return [`Statistics::new_unknown`] /// (the default), not an error. @@ -576,6 +566,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { } /// Handle the result of a child pushdown. + /// /// This method is called as we recurse back up the plan tree after pushing /// filters down to child nodes via [`ExecutionPlan::gather_filters_for_pushdown`]. /// It allows the current node to process the results of filter pushdown from @@ -708,6 +699,19 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Result>> { Ok(SortOrderPushdownResult::Unsupported) } + + /// Returns a variant of this `ExecutionPlan` that is aware of order-sensitivity. + /// + /// This is used to signal to data sources that the output ordering must be + /// preserved, even if it might be more efficient to ignore it (e.g. by + /// skipping some row groups in Parquet). + /// + fn with_preserve_order( + &self, + _preserve_order: bool, + ) -> Option> { + None + } } /// [`ExecutionPlan`] Invariant Level @@ -1384,6 +1388,30 @@ pub fn check_not_null_constraints( Ok(batch) } +/// Make plan ready to be re-executed returning its clone with state reset for all nodes. +/// +/// Some plans will change their internal states after execution, making them unable to be executed again. +/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. +/// +/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. 
+/// However, if the data of the left table is derived from the work table, it will become outdated +/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. +/// +/// # Limitations +/// +/// While this function enables plan reuse, it does not allow the same plan to be executed if it (OR): +/// +/// * uses dynamic filters, +/// * represents a recursive query. +/// +pub fn reset_plan_states(plan: Arc) -> Result> { + plan.transform_up(|plan| { + let new_plan = Arc::clone(&plan).reset_state()?; + Ok(Transformed::yes(new_plan)) + }) + .data() +} + /// Utility function yielding a string representation of the given [`ExecutionPlan`]. pub fn get_plan_string(plan: &Arc) -> Vec { let formatted = displayable(plan.as_ref()).indent(true).to_string(); @@ -1469,10 +1497,6 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { - unimplemented!() - } - fn partition_statistics(&self, _partition: Option) -> Result { unimplemented!() } @@ -1536,10 +1560,6 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { - unimplemented!() - } - fn partition_statistics(&self, _partition: Option) -> Result { unimplemented!() } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 674fe6692adf5..2af0731fb7a63 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -20,19 +20,19 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll, ready}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use itertools::Itertools; use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::coalesce::LimitedBatchCoalescer; -use crate::coalesce::PushBatchStatus::LimitReached; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::common::can_project; use 
crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, PushedDown, PushedDownPredicate, + FilterPushdownPropagation, PushedDown, }; use crate::metrics::{MetricBuilder, MetricType}; use crate::projection::{ @@ -58,7 +58,7 @@ use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{BinaryExpr, Column, lit}; use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; use datafusion_physical_expr::{ AcrossPartitions, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, analyze, conjunction, split_conjunction, @@ -86,46 +86,166 @@ pub struct FilterExec { /// Properties equivalence properties, partitioning, etc. cache: PlanProperties, /// The projection indices of the columns in the output schema of join - projection: Option>, + projection: Option, /// Target batch size for output batches batch_size: usize, /// Number of rows to fetch fetch: Option, } +/// Builder for [`FilterExec`] to set optional parameters +pub struct FilterExecBuilder { + predicate: Arc, + input: Arc, + projection: Option, + default_selectivity: u8, + batch_size: usize, + fetch: Option, +} + +impl FilterExecBuilder { + /// Create a new builder with required parameters (predicate and input) + pub fn new(predicate: Arc, input: Arc) -> Self { + Self { + predicate, + input, + projection: None, + default_selectivity: FILTER_EXEC_DEFAULT_SELECTIVITY, + batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, + fetch: None, + } + } + + /// Set the input execution plan + pub fn with_input(mut self, input: Arc) -> Self { + self.input = input; + self + } + + /// Set the predicate expression + pub fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = 
predicate; + self + } + + /// Set the projection, composing with any existing projection. + /// + /// If a projection is already set, the new projection indices are mapped + /// through the existing projection. For example, if the current projection + /// is `[0, 2, 3]` and `apply_projection(Some(vec![0, 2]))` is called, the + /// resulting projection will be `[0, 3]` (indices 0 and 2 of `[0, 2, 3]`). + /// + /// If no projection is currently set, the new projection is used directly. + /// If `None` is passed, the projection is cleared. + pub fn apply_projection(self, projection: Option>) -> Result { + let projection = projection.map(Into::into); + self.apply_projection_by_ref(projection.as_ref()) + } + + /// The same as [`Self::apply_projection`] but takes projection shared reference. + pub fn apply_projection_by_ref( + mut self, + projection: Option<&ProjectionRef>, + ) -> Result { + // Check if the projection is valid against current output schema + can_project(&self.input.schema(), projection.map(AsRef::as_ref))?; + self.projection = combine_projections(projection, self.projection.as_ref())?; + Ok(self) + } + + /// Set the default selectivity + pub fn with_default_selectivity(mut self, default_selectivity: u8) -> Self { + self.default_selectivity = default_selectivity; + self + } + + /// Set the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the fetch limit + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Build the FilterExec, computing properties once with all configured parameters + pub fn build(self) -> Result { + // Validate predicate type + match self.predicate.data_type(self.input.schema().as_ref())? 
{ + DataType::Boolean => {} + other => { + return plan_err!( + "Filter predicate must return BOOLEAN values, got {other:?}" + ); + } + } + + // Validate selectivity + if self.default_selectivity > 100 { + return plan_err!( + "Default filter selectivity value needs to be less than or equal to 100" + ); + } + + // Validate projection if provided + can_project(&self.input.schema(), self.projection.as_deref())?; + + // Compute properties once with all parameters + let cache = FilterExec::compute_properties( + &self.input, + &self.predicate, + self.default_selectivity, + self.projection.as_deref(), + )?; + + Ok(FilterExec { + predicate: self.predicate, + input: self.input, + metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: self.default_selectivity, + cache, + projection: self.projection, + batch_size: self.batch_size, + fetch: self.fetch, + }) + } +} + +impl From<&FilterExec> for FilterExecBuilder { + fn from(exec: &FilterExec) -> Self { + Self { + predicate: Arc::clone(&exec.predicate), + input: Arc::clone(&exec.input), + projection: exec.projection.clone(), + default_selectivity: exec.default_selectivity, + batch_size: exec.batch_size, + fetch: exec.fetch, + // We could cache / copy over PlanProperties + // here but that would require invalidating them in FilterExecBuilder::apply_projection, etc. + // and currently every call to this method ends up invalidating them anyway. + // If useful this can be added in the future as a non-breaking change. + } + } +} + impl FilterExec { - /// Create a FilterExec on an input - #[expect(clippy::needless_pass_by_value)] + /// Create a FilterExec on an input using the builder pattern pub fn try_new( predicate: Arc, input: Arc, ) -> Result { - match predicate.data_type(input.schema().as_ref())? 
{ - DataType::Boolean => { - let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; - let cache = Self::compute_properties( - &input, - &predicate, - default_selectivity, - None, - )?; - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - projection: None, - batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, - fetch: None, - }) - } - other => { - plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") - } - } + FilterExecBuilder::new(predicate, input).build() } + /// Get a batch size + pub fn batch_size(&self) -> usize { + self.batch_size + } + + /// Set the default selectivity pub fn with_default_selectivity( mut self, default_selectivity: u8, @@ -140,36 +260,19 @@ impl FilterExec { } /// Return new instance of [FilterExec] with the given projection. + /// + /// # Deprecated + /// Use [`FilterExecBuilder::apply_projection`] instead + #[deprecated( + since = "52.0.0", + note = "Use FilterExecBuilder::apply_projection instead" + )] pub fn with_projection(&self, projection: Option>) -> Result { - // Check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - - let cache = Self::compute_properties( - &self.input, - &self.predicate, - self.default_selectivity, - projection.as_ref(), - )?; - Ok(Self { - predicate: Arc::clone(&self.predicate), - input: Arc::clone(&self.input), - metrics: self.metrics.clone(), - default_selectivity: self.default_selectivity, - cache, - projection, - batch_size: self.batch_size, - fetch: self.fetch, - }) + let builder = FilterExecBuilder::from(self); + builder.apply_projection(projection)?.build() } + /// Set the batch size pub fn with_batch_size(&self, batch_size: usize) -> Result { Ok(Self { predicate: 
Arc::clone(&self.predicate), @@ -199,8 +302,8 @@ impl FilterExec { } /// Projection - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. @@ -277,7 +380,7 @@ impl FilterExec { input: &Arc, predicate: &Arc, default_selectivity: u8, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: @@ -316,7 +419,7 @@ impl FilterExec { if let Some(projection) = projection { let schema = eq_properties.schema(); let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -400,13 +503,11 @@ impl ExecutionPlan for FilterExec { self: Arc, mut children: Vec>, ) -> Result> { - FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) - .and_then(|e| { - let selectivity = e.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .and_then(|e| e.with_projection(self.projection().cloned())) - .map(|e| e.with_fetch(self.fetch).unwrap()) + let new_input = children.swap_remove(0); + FilterExecBuilder::from(&*self) + .with_input(new_input) + .build() + .map(|e| Arc::new(e) as _) } fn execute( @@ -441,15 +542,10 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. 
- fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stats = self.input.partition_statistics(partition)?; - let schema = self.schema(); let stats = Self::statistics_helper( - &schema, + &self.input.schema(), input_stats, self.predicate(), self.default_selectivity, @@ -473,15 +569,11 @@ impl ExecutionPlan for FilterExec { if let Some(new_predicate) = update_expr(self.predicate(), projection.expr(), false)? { - return FilterExec::try_new( - new_predicate, - make_with_child(projection, self.input())?, - ) - .and_then(|e| { - let selectivity = self.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .map(|e| Some(Arc::new(e) as _)); + return FilterExecBuilder::from(self) + .with_input(make_with_child(projection, self.input())?) + .with_predicate(new_predicate) + .build() + .map(|e| Some(Arc::new(e) as _)); } } try_embed_projection(projection, self) @@ -494,16 +586,9 @@ impl ExecutionPlan for FilterExec { _config: &ConfigOptions, ) -> Result { if !matches!(phase, FilterPushdownPhase::Pre) { - // For non-pre phase, filters pass through unchanged - let filter_supports = parent_filters - .into_iter() - .map(PushedDownPredicate::supported) - .collect(); - - return Ok(FilterDescription::new().with_child(ChildFilterDescription { - parent_filters: filter_supports, - self_filters: vec![], - })); + let child = + ChildFilterDescription::from_child(&parent_filters, self.input())?; + return Ok(FilterDescription::new().with_child(child)); } let child = ChildFilterDescription::from_child(&parent_filters, self.input())? 
@@ -527,10 +612,26 @@ impl ExecutionPlan for FilterExec { return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); } // We absorb any parent filters that were not handled by our children - let unsupported_parent_filters = - child_pushdown_result.parent_filters.iter().filter_map(|f| { - matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) - }); + let mut unsupported_parent_filters: Vec> = + child_pushdown_result + .parent_filters + .iter() + .filter_map(|f| { + matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) + }) + .collect(); + + // If this FilterExec has a projection, the unsupported parent filters + // are in the output schema (after projection) coordinates. We need to + // remap them to the input schema coordinates before combining with self filters. + if self.projection.is_some() { + let input_schema = self.input().schema(); + unsupported_parent_filters = unsupported_parent_filters + .into_iter() + .map(|expr| reassign_expr_columns(expr, &input_schema)) + .collect::>>()?; + } + let unsupported_self_filters = child_pushdown_result .self_filters .first() @@ -552,7 +653,7 @@ impl ExecutionPlan for FilterExec { let new_predicate = conjunction(unhandled_filters); let updated_node = if new_predicate.eq(&lit(true)) { // FilterExec is no longer needed, but we may need to leave a projection in place - match self.projection() { + match self.projection().as_ref() { Some(projection_indices) => { let filter_child_schema = filter_input.schema(); let proj_exprs = projection_indices @@ -578,7 +679,7 @@ impl ExecutionPlan for FilterExec { // The new predicate is the same as our current predicate None } else { - // Create a new FilterExec with the new predicate + // Create a new FilterExec with the new predicate, preserving the projection let new = FilterExec { predicate: Arc::clone(&new_predicate), input: Arc::clone(&filter_input), @@ -588,9 +689,9 @@ impl ExecutionPlan for FilterExec { &filter_input, &new_predicate, 
self.default_selectivity, - self.projection.as_ref(), + self.projection.as_deref(), )?, - projection: None, + projection: self.projection.clone(), batch_size: self.batch_size, fetch: self.fetch, }; @@ -615,11 +716,26 @@ impl ExecutionPlan for FilterExec { fetch, })) } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl EmbeddedProjection for FilterExec { fn with_projection(&self, projection: Option>) -> Result { - self.with_projection(projection) + FilterExecBuilder::from(self) + .apply_projection(projection)? + .build() } } @@ -685,7 +801,7 @@ struct FilterExecStream { /// Runtime metrics recording metrics: FilterExecMetrics, /// The projection indices of the columns in the input schema - projection: Option>, + projection: Option, /// Batch coalescer to combine small batches batch_coalescer: LimitedBatchCoalescer, } @@ -711,23 +827,6 @@ impl FilterExecMetrics { } } -impl FilterExecStream { - fn flush_remaining_batches( - &mut self, - ) -> Poll>> { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - Poll::Ready(self.batch_coalescer.next_completed_batch().map(|batch| { - self.metrics.selectivity.add_part(batch.num_rows()); - Ok(batch) - })) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - } -} - pub fn batch_filter( batch: &RecordBatch, predicate: &Arc, @@ -767,18 +866,34 @@ impl Stream for FilterExecStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll; let elapsed_compute = self.metrics.baseline_metrics.elapsed_compute().clone(); loop { + // If there is a completed batch ready, return it + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + self.metrics.selectivity.add_part(batch.num_rows()); + let poll = Poll::Ready(Some(Ok(batch))); + return 
self.metrics.baseline_metrics.record_poll(poll); + } + + if self.batch_coalescer.is_finished() { + // If input is done and no batches are ready, return None to signal end of stream. + return Poll::Ready(None); + } + + // Attempt to pull the next batch from the input stream. match ready!(self.input.poll_next_unpin(cx)) { + None => { + self.batch_coalescer.finish()?; + // continue draining the coalescer + } Some(Ok(batch)) => { let timer = elapsed_compute.timer(); let status = self.predicate.as_ref() .evaluate(&batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - Ok(match self.projection { - Some(ref projection) => { + Ok(match self.projection.as_ref() { + Some(projection) => { let projected_batch = batch.project(projection)?; (array, projected_batch) }, @@ -802,37 +917,22 @@ impl Stream for FilterExecStream { })?; timer.done(); - if let LimitReached = status { - poll = self.flush_remaining_batches(); - break; - } - - if let Some(batch) = self.batch_coalescer.next_completed_batch() { - self.metrics.selectivity.add_part(batch.num_rows()); - poll = Poll::Ready(Some(Ok(batch))); - break; - } - continue; - } - None => { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - poll = self.flush_remaining_batches(); + match status { + PushBatchStatus::Continue => { + // Keep pushing more batches } - Err(e) => { - poll = Poll::Ready(Some(Err(e))); + PushBatchStatus::LimitReached => { + // limit was reached, so stop early + self.batch_coalescer.finish()?; + // continue draining the coalescer } } - break; - } - value => { - poll = Poll::Ready(value); - break; } + + // Error case + other => return Poll::Ready(other), } } - self.metrics.baseline_metrics.record_poll(poll) } fn size_hint(&self) -> (usize, Option) { @@ -1557,13 +1657,14 @@ mod tests { #[test] fn test_equivalence_properties_union_type() -> Result<()> { let union_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![0, 1], vec![ 
Field::new("f1", DataType::Int32, true), Field::new("f2", DataType::Utf8, true), ], - ), + ) + .unwrap(), UnionMode::Sparse, ); @@ -1586,4 +1687,370 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_builder_with_projection() -> Result<()> { + // Create a schema with multiple columns + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Create filter with projection [0, 2] (columns a and c) using builder + let projection = Some(vec![0, 2]); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Verify projection is set correctly + assert_eq!(filter.projection(), &Some([0, 2].into())); + + // Verify schema contains only projected columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + assert_eq!(output_schema.field(1).name(), "c"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_without_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Create filter without projection using builder + let filter = FilterExecBuilder::new(predicate, input).build()?; + + // Verify no projection is set + assert!(filter.projection().is_none()); + + // Verify schema contains all columns + let 
output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_invalid_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Try to create filter with invalid projection (index out of bounds) using builder + let result = + FilterExecBuilder::new(predicate, input).apply_projection(Some(vec![0, 5])); // 5 is out of bounds + + // Should return an error + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_vs_with_projection() -> Result<()> { + // This test verifies that the builder with projection produces the same result + // as try_new().with_projection(), but more efficiently (one compute_properties call) + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ]); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ], + }, + schema, + )); + let input: Arc = input; + + let predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let projection = 
Some(vec![0, 2]); + + // Method 1: Builder with projection (one call to compute_properties) + let filter1 = FilterExecBuilder::new(Arc::clone(&predicate), Arc::clone(&input)) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Method 2: Also using builder for comparison (deprecated try_new().with_projection() removed) + let filter2 = FilterExecBuilder::new(predicate, input) + .apply_projection(projection) + .unwrap() + .build()?; + + // Both methods should produce equivalent results + assert_eq!(filter1.schema(), filter2.schema()); + assert_eq!(filter1.projection(), filter2.projection()); + + // Verify statistics are the same + let stats1 = filter1.partition_statistics(None)?; + let stats2 = filter2.partition_statistics(None)?; + assert_eq!(stats1.num_rows, stats2.num_rows); + assert_eq!(stats1.total_byte_size, stats2.total_byte_size); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_statistics_with_projection() -> Result<()> { + // Test that statistics are correctly computed when using builder with projection + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(12000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + ..Default::default() + }, + ], + }, + schema, + )); + + // Filter: a < 50, Project: [0, 2] + let 
predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2])) + .unwrap() + .build()?; + + let statistics = filter.partition_statistics(None)?; + + // Verify statistics reflect both filtering and projection + assert!(matches!(statistics.num_rows, Precision::Inexact(_))); + + // Schema should only have 2 columns after projection + assert_eq!(filter.schema().fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_builder_predicate_validation() -> Result<()> { + // Test that builder validates predicate type correctly + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a predicate that doesn't return boolean (returns Int32) + let invalid_predicate = Arc::new(Column::new("a", 0)); + + // Should fail because predicate doesn't return boolean + let result = FilterExecBuilder::new(invalid_predicate, input) + .apply_projection(Some(vec![0])) + .unwrap() + .build(); + + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition() -> Result<()> { + // Test that calling apply_projection multiple times composes projections + // If initial projection is [0, 2, 3] and we call apply_projection([0, 2]), + // the result should be [0, 3] (indices 0 and 2 of [0, 2, 3]) + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + 
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // First projection: [0, 2, 3] -> select columns a, c, d + // Second projection: [0, 2] -> select indices 0 and 2 of [0, 2, 3] -> [0, 3] + // Final result: columns a and d + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2, 3]))? + .apply_projection(Some(vec![0, 2]))? + .build()?; + + // Verify composed projection is [0, 3] + assert_eq!(filter.projection(), &Some([0, 3].into())); + + // Verify schema contains only columns a and d + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + assert_eq!(output_schema.field(1).name(), "d"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition_none_clears() -> Result<()> { + // Test that passing None clears the projection + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Set a projection then clear it with None + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0]))? + .apply_projection(None)? + .build()?; + + // Projection should be cleared + assert_eq!(filter.projection(), &None); + + // Schema should have all columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_filter_with_projection_remaps_post_phase_parent_filters() -> Result<()> { + // Test that FilterExec with a projection must remap parent dynamic + // filter column indices from its output schema to the input schema + // before passing them to the child. 
+ let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let input = Arc::new(EmptyExec::new(Arc::clone(&input_schema))); + + // FilterExec: a > 0, projection=[c@2] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![2]))? + .build()?; + + // Output schema should be [c:Float64] + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 1); + assert_eq!(output_schema.field(0).name(), "c"); + + // Simulate a parent dynamic filter referencing output column c@0 + let parent_filter: Arc = Arc::new(Column::new("c", 0)); + + let config = ConfigOptions::new(); + let desc = filter.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![parent_filter], + &config, + )?; + + // The filter pushed to the child must reference c@2 (input schema), + // not c@0 (output schema). 
+ let parent_filters = desc.parent_filters(); + assert_eq!(parent_filters.len(), 1); // one child + assert_eq!(parent_filters[0].len(), 1); // one filter + let remapped = &parent_filters[0][0].predicate; + let display = format!("{remapped}"); + assert_eq!( + display, "c@2", + "Post-phase parent filter column index must be remapped \ + from output schema (c@0) to input schema (c@2)" + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 1274e954eaeb3..7e82b9e8239e0 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -37,10 +37,13 @@ use std::collections::HashSet; use std::sync::Arc; -use datafusion_common::Result; -use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; +use arrow_schema::SchemaRef; +use datafusion_common::{ + Result, + tree_node::{Transformed, TreeNode}, +}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use itertools::Itertools; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FilterPushdownPhase { @@ -217,13 +220,13 @@ pub struct ChildPushdownResult { /// Returned from [`ExecutionPlan::handle_child_pushdown_result`] to communicate /// to the optimizer: /// -/// 1. What to do with any parent filters that were could not be pushed down into the children. +/// 1. What to do with any parent filters that could not be pushed down into the children. /// 2. If the node needs to be replaced in the execution plan with a new node or not. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result #[derive(Debug, Clone)] pub struct FilterPushdownPropagation { - /// What filters were pushed into the parent node. + /// Which parent filters were pushed down into this node's children. 
pub filters: Vec, /// The updated node, if it was updated during pushdown pub updated_node: Option, @@ -306,6 +309,83 @@ pub struct ChildFilterDescription { pub(crate) self_filters: Vec>, } +/// Validates and remaps filter column references to a target schema in one step. +/// +/// When pushing filters from a parent to a child node, we need to: +/// 1. Verify that all columns referenced by the filter exist in the target +/// 2. Remap column indices to match the target schema +/// +/// `allowed_indices` controls which column indices (in the parent schema) are +/// considered valid. For single-input nodes this defaults to +/// `0..child_schema.len()` (all columns are reachable). For join nodes it is +/// restricted to the subset of output columns that map to the target child, +/// which is critical when different sides have same-named columns. +pub(crate) struct FilterRemapper { + /// The target schema to remap column indices into. + child_schema: SchemaRef, + /// Only columns at these indices (in the *parent* schema) are considered + /// valid. For non-join nodes this defaults to `0..child_schema.len()`. + allowed_indices: HashSet, +} + +impl FilterRemapper { + /// Create a remapper that accepts any column whose index falls within + /// `0..child_schema.len()` and whose name exists in the target schema. + pub(crate) fn new(child_schema: SchemaRef) -> Self { + let allowed_indices = (0..child_schema.fields().len()).collect(); + Self { + child_schema, + allowed_indices, + } + } + + /// Create a remapper that only accepts columns at the given indices. + /// This is used by join nodes to restrict pushdown to one side of the + /// join when both sides have same-named columns. + fn with_allowed_indices( + child_schema: SchemaRef, + allowed_indices: HashSet, + ) -> Self { + Self { + child_schema, + allowed_indices, + } + } + + /// Try to remap a filter's column references to the target schema. 
+ /// + /// Validates and remaps in a single tree traversal: for each column, + /// checks that its index is in the allowed set and that + /// its name exists in the target schema, then remaps the index. + /// Returns `Some(remapped)` if all columns are valid, or `None` if any + /// column fails validation. + pub(crate) fn try_remap( + &self, + filter: &Arc, + ) -> Result>> { + let mut all_valid = true; + let transformed = Arc::clone(filter).transform_down(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + if self.allowed_indices.contains(&col.index()) + && let Ok(new_index) = self.child_schema.index_of(col.name()) + { + Ok(Transformed::yes(Arc::new(Column::new( + col.name(), + new_index, + )))) + } else { + all_valid = false; + Ok(Transformed::complete(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + })?; + + Ok(all_valid.then_some(transformed.data)) + } +} + impl ChildFilterDescription { /// Build a child filter description by analyzing which parent filters can be pushed to a specific child. /// @@ -318,36 +398,41 @@ impl ChildFilterDescription { parent_filters: &[Arc], child: &Arc, ) -> Result { - let child_schema = child.schema(); + let remapper = FilterRemapper::new(child.schema()); + Self::remap_filters(parent_filters, &remapper) + } - // Get column names from child schema for quick lookup - let child_column_names: HashSet<&str> = child_schema - .fields() - .iter() - .map(|f| f.name().as_str()) - .collect(); + /// Like [`Self::from_child`], but restricts which parent-level columns are + /// considered reachable through this child. + /// + /// `allowed_indices` is the set of column indices (in the *parent* + /// schema) that map to this child's side of a join. A filter is only + /// eligible for pushdown when **every** column index it references + /// appears in `allowed_indices`. 
+ /// + /// This prevents incorrect pushdown when different join sides have + /// columns with the same name: matching on index ensures a filter + /// referencing the right side's `k@2` is not pushed to the left side + /// which also has a column named `k` but at a different index. + pub fn from_child_with_allowed_indices( + parent_filters: &[Arc], + allowed_indices: HashSet, + child: &Arc, + ) -> Result { + let remapper = + FilterRemapper::with_allowed_indices(child.schema(), allowed_indices); + Self::remap_filters(parent_filters, &remapper) + } - // Analyze each parent filter + fn remap_filters( + parent_filters: &[Arc], + remapper: &FilterRemapper, + ) -> Result { let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); - for filter in parent_filters { - // Check which columns the filter references - let referenced_columns = collect_columns(filter); - - // Check if all referenced columns exist in the child schema - let all_columns_exist = referenced_columns - .iter() - .all(|col| child_column_names.contains(col.name())); - - if all_columns_exist { - // All columns exist in child - we can push down - // Need to reassign column indices to match child schema - let reassigned_filter = - reassign_expr_columns(Arc::clone(filter), &child_schema)?; - child_parent_filters - .push(PushedDownPredicate::supported(reassigned_filter)); + if let Some(remapped) = remapper.try_remap(filter)? { + child_parent_filters.push(PushedDownPredicate::supported(remapped)); } else { - // Some columns don't exist in child - cannot push down child_parent_filters .push(PushedDownPredicate::unsupported(Arc::clone(filter))); } @@ -359,6 +444,17 @@ impl ChildFilterDescription { }) } + /// Mark all parent filters as unsupported for this child. 
+ pub fn all_unsupported(parent_filters: &[Arc]) -> Self { + Self { + parent_filters: parent_filters + .iter() + .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) + .collect(), + self_filters: vec![], + } + } + /// Add a self filter (from the current node) to be pushed down to this child. pub fn with_self_filter(mut self, filter: Arc) -> Self { self.self_filters.push(filter); @@ -434,15 +530,9 @@ impl FilterDescription { children: &[&Arc], ) -> Self { let mut desc = Self::new(); - let child_filters = parent_filters - .iter() - .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) - .collect_vec(); for _ in 0..children.len() { - desc = desc.with_child(ChildFilterDescription { - parent_filters: child_filters.clone(), - self_filters: vec![], - }); + desc = + desc.with_child(ChildFilterDescription::all_unsupported(parent_filters)); } desc } diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4f32b6176ec39..d5b540885efae 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -206,7 +206,7 @@ async fn load_left_input( let (batches, _metrics, reservation) = stream .try_fold( (Vec::new(), metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; @@ -356,10 +356,6 @@ impl ExecutionPlan for CrossJoinExec { } } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // Get the all partitions statistics of the left let left_stats = self.left.partition_statistics(None)?; diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index beca48a5b7d50..f39208bcb78d0 100644 --- 
a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -15,16 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::fmt; use std::mem::size_of; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; use crate::ExecutionPlanProperties; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; use crate::joins::Map; @@ -80,7 +81,8 @@ use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumula use datafusion_physical_expr::equivalence::{ ProjectionMapping, join_equivalence_properties, }; -use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; @@ -206,6 +208,11 @@ pub(super) struct JoinLeftData { /// Membership testing strategy for filter pushdown /// Contains either InList values for small build sides or hash table reference for large build sides pub(super) membership: PushdownStrategy, + /// Shared atomic flag indicating if any probe partition saw data (for null-aware anti joins) + /// This is shared across all probe partitions to provide global knowledge + pub(super) probe_side_non_empty: AtomicBool, + /// Shared atomic flag indicating if any probe partition saw NULL in join keys (for null-aware anti joins) + pub(super) probe_side_has_null: AtomicBool, } impl JoinLeftData { @@ -241,6 +248,187 @@ impl 
JoinLeftData { } } +/// Helps to build [`HashJoinExec`]. +pub struct HashJoinExecBuilder { + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + filter: Option, + projection: Option, + partition_mode: PartitionMode, + null_equality: NullEquality, + null_aware: bool, + /// Maximum number of rows to return + /// + /// If the operator produces `< fetch` rows, it returns all available rows. + /// If it produces `>= fetch` rows, it returns exactly `fetch` rows and stops early. + fetch: Option, +} + +impl HashJoinExecBuilder { + /// Make a new [`HashJoinExecBuilder`]. + pub fn new( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + ) -> Self { + Self { + left, + right, + on, + filter: None, + projection: None, + partition_mode: PartitionMode::Auto, + join_type, + null_equality: NullEquality::NullEqualsNothing, + null_aware: false, + fetch: None, + } + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.projection = projection; + self + } + + /// Set optional filter. + pub fn with_filter(mut self, filter: Option) -> Self { + self.filter = filter; + self + } + + /// Set partition mode. + pub fn with_partition_mode(mut self, mode: PartitionMode) -> Self { + self.partition_mode = mode; + self + } + + /// Set null equality property. + pub fn with_null_equality(mut self, null_equality: NullEquality) -> Self { + self.null_equality = null_equality; + self + } + + /// Set null aware property. + pub fn with_null_aware(mut self, null_aware: bool) -> Self { + self.null_aware = null_aware; + self + } + + /// Set fetch limit. + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Build resulting execution plan. 
+ pub fn build(self) -> Result { + let Self { + left, + right, + on, + join_type, + filter, + projection, + partition_mode, + null_equality, + null_aware, + fetch, + } = self; + + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return plan_err!("On constraints in HashJoinExec should be non-empty"); + } + + check_join_is_valid(&left_schema, &right_schema, &on)?; + + // Validate null_aware flag + if null_aware { + if !matches!(join_type, JoinType::LeftAnti) { + return plan_err!( + "null_aware can only be true for LeftAnti joins, got {join_type}" + ); + } + if on.len() != 1 { + return plan_err!( + "null_aware anti join only supports single column join key, got {} columns", + on.len() + ); + } + } + + let (join_schema, column_indices) = + build_join_schema(&left_schema, &right_schema, &join_type); + + let random_state = HASH_JOIN_SEED; + + let join_schema = Arc::new(join_schema); + + // check if the projection is valid + can_project(&join_schema, projection.as_deref())?; + + let cache = HashJoinExec::compute_properties( + &left, + &right, + &join_schema, + join_type, + &on, + partition_mode, + projection.as_deref(), + )?; + + // Initialize both dynamic filter and bounds accumulator to None + // They will be set later if dynamic filtering is enabled + + Ok(HashJoinExec { + left, + right, + on, + filter, + join_type, + join_schema, + left_fut: Default::default(), + random_state, + mode: partition_mode, + metrics: ExecutionPlanMetricsSet::new(), + projection, + column_indices, + null_equality, + null_aware, + cache, + dynamic_filter: None, + fetch, + }) + } +} + +impl From<&HashJoinExec> for HashJoinExecBuilder { + fn from(exec: &HashJoinExec) -> Self { + Self { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + on: exec.on.clone(), + join_type: exec.join_type, + filter: exec.filter.clone(), + projection: exec.projection.clone(), + partition_mode: exec.mode, + null_equality: exec.null_equality, + null_aware: 
exec.null_aware, + fetch: exec.fetch, + } + } +} + #[expect(rustdoc::private_intra_doc_links)] /// Join execution plan: Evaluates equijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post @@ -461,17 +649,21 @@ pub struct HashJoinExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The projection indices of the columns in the output schema of join - pub projection: Option>, + pub projection: Option, /// Information of index and left / right placement of columns column_indices: Vec, /// The equality null-handling behavior of the join algorithm. pub null_equality: NullEquality, + /// Flag to indicate if this is a null-aware anti join + pub null_aware: bool, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, /// Dynamic filter for pushing down to the probe side /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. 
dynamic_filter: Option, + /// Maximum number of rows to return + fetch: Option, } #[derive(Clone)] @@ -526,55 +718,15 @@ impl HashJoinExec { projection: Option>, partition_mode: PartitionMode, null_equality: NullEquality, + null_aware: bool, ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - if on.is_empty() { - return plan_err!("On constraints in HashJoinExec should be non-empty"); - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - - let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); - - let random_state = HASH_JOIN_SEED; - - let join_schema = Arc::new(join_schema); - - // check if the projection is valid - can_project(&join_schema, projection.as_ref())?; - - let cache = Self::compute_properties( - &left, - &right, - &join_schema, - *join_type, - &on, - partition_mode, - projection.as_ref(), - )?; - - // Initialize both dynamic filter and bounds accumulator to None - // They will be set later if dynamic filtering is enabled - - Ok(HashJoinExec { - left, - right, - on, - filter, - join_type: *join_type, - join_schema, - left_fut: Default::default(), - random_state, - mode: partition_mode, - metrics: ExecutionPlanMetricsSet::new(), - projection, - column_indices, - null_equality, - cache, - dynamic_filter: None, - }) + HashJoinExecBuilder::new(left, right, on, *join_type) + .with_filter(filter) + .with_projection(projection) + .with_partition_mode(partition_mode) + .with_null_equality(null_equality) + .with_null_aware(null_aware) + .build() } fn create_dynamic_filter(on: &JoinOn) -> Arc { @@ -585,6 +737,28 @@ impl HashJoinExec { Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) } + fn allow_join_dynamic_filter_pushdown(&self, config: &ConfigOptions) -> bool { + if self.join_type != JoinType::Inner + || !config.optimizer.enable_join_dynamic_filter_pushdown + { + return false; + } + + // `preserve_file_partitions` can report Hash partitioning for Hive-style 
+ // file groups, but those partitions are not actually hash-distributed. + // Partitioned dynamic filters rely on hash routing, so disable them in + // this mode to avoid incorrect results. Follow-up work: enable dynamic + // filtering for preserve_file_partitioned scans (issue #20195). + // https://github.com/apache/datafusion/issues/20195 + if config.optimizer.preserve_file_partitions > 0 + && self.mode == PartitionMode::Partitioned + { + return false; + } + + true + } + /// left (build) side which gets hashed pub fn left(&self) -> &Arc { &self.left @@ -663,25 +837,14 @@ impl HashJoinExec { /// Return new instance of [HashJoinExec] with the given projection. pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.on.clone(), - self.filter.clone(), - &self.join_type, - projection, - self.mode, - self.null_equality, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + HashJoinExecBuilder::from(self) + .with_projection_ref(projection) + .build() } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -692,7 +855,7 @@ impl HashJoinExec { join_type: JoinType, on: JoinOnRef, mode: PartitionMode, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -744,7 +907,7 @@ impl HashJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -787,24 +950,27 @@ impl HashJoinExec { ) -> Result> { let left = self.left(); let right = self.right(); - let new_join = HashJoinExec::try_new( + let new_join = HashJoinExecBuilder::new( Arc::clone(right), Arc::clone(left), self.on() .iter() .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) .collect(), - self.filter().map(JoinFilter::swap), - &self.join_type().swap(), - swap_join_projection( - left.schema().fields().len(), - right.schema().fields().len(), - self.projection.as_ref(), - self.join_type(), - ), - partition_mode, - self.null_equality(), - )?; + self.join_type().swap(), + ) + .with_filter(self.filter().map(JoinFilter::swap)) + .with_projection(swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_deref(), + self.join_type(), + )) + .with_partition_mode(partition_mode) + .with_null_equality(self.null_equality()) + .with_null_aware(self.null_aware) + .with_fetch(self.fetch) + .build()?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( self.join_type(), @@ -855,6 +1021,9 @@ impl DisplayAs for HashJoinExec { } else { 
"" }; + let display_fetch = self + .fetch + .map_or_else(String::new, |f| format!(", fetch={f}")); let on = self .on .iter() @@ -863,13 +1032,14 @@ impl DisplayAs for HashJoinExec { .join(", "); write!( f, - "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}", + "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}{}", self.mode, self.join_type, on, display_filter, display_projections, display_null_equality, + display_fetch, ) } DisplayFormatType::TreeRender => { @@ -896,6 +1066,10 @@ impl DisplayAs for HashJoinExec { writeln!(f, "filter={filter}")?; } + if let Some(fetch) = self.fetch { + writeln!(f, "fetch={fetch}")?; + } + Ok(()) } } @@ -986,6 +1160,7 @@ impl ExecutionPlan for HashJoinExec { projection: self.projection.clone(), column_indices: self.column_indices.clone(), null_equality: self.null_equality, + null_aware: self.null_aware, cache: Self::compute_properties( &children[0], &children[1], @@ -993,10 +1168,11 @@ impl ExecutionPlan for HashJoinExec { self.join_type, &self.on, self.mode, - self.projection.as_ref(), + self.projection.as_deref(), )?, // Keep the dynamic filter, bounds accumulator will be reset dynamic_filter: self.dynamic_filter.clone(), + fetch: self.fetch, })) } @@ -1016,9 +1192,11 @@ impl ExecutionPlan for HashJoinExec { projection: self.projection.clone(), column_indices: self.column_indices.clone(), null_equality: self.null_equality, + null_aware: self.null_aware, cache: self.cache.clone(), // Reset dynamic filter and bounds accumulator to initial state dynamic_filter: None, + fetch: self.fetch, })) } @@ -1053,11 +1231,8 @@ impl ExecutionPlan for HashJoinExec { // - A dynamic filter exists // - At least one consumer is holding a reference to it, this avoids expensive filter // computation when disabled or when no consumer will use it. 
- let enable_dynamic_filter_pushdown = context - .session_config() - .options() - .optimizer - .enable_join_dynamic_filter_pushdown + let enable_dynamic_filter_pushdown = self + .allow_join_dynamic_filter_pushdown(context.session_config().options()) && self .dynamic_filter .as_ref() @@ -1153,7 +1328,7 @@ impl ExecutionPlan for HashJoinExec { let right_stream = self.right.execute(partition, context)?; // update column indices to reflect the projection - let column_indices_after_projection = match &self.projection { + let column_indices_after_projection = match self.projection.as_ref() { Some(projection) => projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -1185,6 +1360,8 @@ impl ExecutionPlan for HashJoinExec { self.right.output_ordering().is_some(), build_accumulator, self.mode, + self.null_aware, + self.fetch, ))) } @@ -1192,10 +1369,6 @@ impl ExecutionPlan for HashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema())); @@ -1211,7 +1384,9 @@ impl ExecutionPlan for HashJoinExec { &self.join_schema, )?; // Project statistics if there is a projection - Ok(stats.project(self.projection.as_ref())) + let stats = stats.project(self.projection.as_ref()); + // Apply fetch limit to statistics + stats.with_fetch(self.fetch, 0, 1) } /// Tries to push `projection` down through `hash_join`. If possible, performs the @@ -1240,17 +1415,22 @@ impl ExecutionPlan for HashJoinExec { &schema, self.filter(), )? 
{ - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::new(projected_left_child), - Arc::new(projected_right_child), - join_on, - join_filter, - self.join_type(), + Ok(Some(Arc::new( + HashJoinExecBuilder::new( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_on, + *self.join_type(), + ) + .with_filter(join_filter) // Returned early if projection is not None - None, - *self.partition_mode(), - self.null_equality, - )?))) + .with_projection(None) + .with_partition_mode(*self.partition_mode()) + .with_null_equality(self.null_equality) + .with_null_aware(self.null_aware) + .with_fetch(self.fetch) + .build()?, + ))) } else { try_embed_projection(projection, self) } @@ -1262,30 +1442,111 @@ impl ExecutionPlan for HashJoinExec { parent_filters: Vec>, config: &ConfigOptions, ) -> Result { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. - // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - // See https://github.com/apache/datafusion/issues/16973 for tracking. - if self.join_type != JoinType::Inner { - return Ok(FilterDescription::all_unsupported( - &parent_filters, - &self.children(), - )); + // This is the physical-plan equivalent of `push_down_all_join` in + // `datafusion/optimizer/src/push_down_filter.rs`. That function uses `lr_is_preserved` + // to decide which parent predicates can be pushed past a logical join to its children, + // then checks column references to route each predicate to the correct side. + // + // We apply the same two-level logic here: + // 1. `lr_is_preserved` gates whether a side is eligible at all. + // 2. For each filter, we check that all column references belong to the + // target child (using `column_indices` to map output column positions + // to join sides). 
This is critical for correctness: name-based matching + // alone (as done by `ChildFilterDescription::from_child`) can incorrectly + // push filters when different join sides have columns with the same name + // (e.g. nested mark joins both producing "mark" columns). + let (left_preserved, right_preserved) = lr_is_preserved(self.join_type); + + // Build the set of allowed column indices for each side + let column_indices: Vec = match self.projection.as_ref() { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let (mut left_allowed, mut right_allowed) = (HashSet::new(), HashSet::new()); + column_indices + .iter() + .enumerate() + .for_each(|(output_idx, ci)| { + match ci.side { + JoinSide::Left => left_allowed.insert(output_idx), + JoinSide::Right => right_allowed.insert(output_idx), + // Mark columns - don't allow pushdown to either side + JoinSide::None => false, + }; + }); + + // For semi/anti joins, the non-preserved side's columns are not in the + // output, but filters on join key columns can still be pushed there. + // We find output columns that are join keys on the preserved side and + // add their output indices to the non-preserved side's allowed set. + // The name-based remap in FilterRemapper will then match them to the + // corresponding column in the non-preserved child's schema. 
+ match self.join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + let left_key_indices: HashSet = self + .on + .iter() + .filter_map(|(left_key, _)| { + left_key + .as_any() + .downcast_ref::() + .map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Left && left_key_indices.contains(&ci.index) { + right_allowed.insert(output_idx); + } + } + } + JoinType::RightSemi | JoinType::RightAnti => { + let right_key_indices: HashSet = self + .on + .iter() + .filter_map(|(_, right_key)| { + right_key + .as_any() + .downcast_ref::() + .map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Right && right_key_indices.contains(&ci.index) + { + left_allowed.insert(output_idx); + } + } + } + _ => {} } - // Get basic filter descriptions for both children - let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.left(), - )?; - let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.right(), - )?; + let left_child = if left_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + left_allowed, + self.left(), + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; + + let mut right_child = if right_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + right_allowed, + self.right(), + )? 
+ } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; // Add dynamic filters in Post phase if enabled if matches!(phase, FilterPushdownPhase::Post) - && config.optimizer.enable_join_dynamic_filter_pushdown + && self.allow_join_dynamic_filter_pushdown(config) { // Add actual dynamic filter to right side (probe side) let dynamic_filter = Self::create_dynamic_filter(&self.on); @@ -1303,19 +1564,6 @@ impl ExecutionPlan for HashJoinExec { child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { - // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for - // non-inner joins in `gather_filters_for_pushdown`. - // However it's a cheap check and serves to inform future devs touching this function that they need to be really - // careful pushing down filters through non-inner joins. - if self.join_type != JoinType::Inner { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. 
- // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - return Ok(FilterPushdownPropagation::all_unsupported( - child_pushdown_result, - )); - } - let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child @@ -1342,17 +1590,59 @@ impl ExecutionPlan for HashJoinExec { projection: self.projection.clone(), column_indices: self.column_indices.clone(), null_equality: self.null_equality, + null_aware: self.null_aware, cache: self.cache.clone(), dynamic_filter: Some(HashJoinExecDynamicFilter { filter: dynamic_filter, build_accumulator: OnceLock::new(), }), + fetch: self.fetch, }); result = result.with_updated_node(new_node as Arc); } } Ok(result) } + + fn supports_limit_pushdown(&self) -> bool { + // Hash join execution plan does not support pushing limit down through to children + // because the children don't know about the join condition and can't + // determine how many rows to produce + false + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn with_fetch(&self, limit: Option) -> Option> { + HashJoinExecBuilder::from(self) + .with_fetch(limit) + .build() + .ok() + .map(|exec| Arc::new(exec) as _) + } +} + +/// Determines which sides of a join are "preserved" for filter pushdown. +/// +/// A preserved side means filters on that side's columns can be safely pushed +/// below the join. This mirrors the logic in the logical optimizer's +/// `lr_is_preserved` in `datafusion/optimizer/src/push_down_filter.rs`. 
+fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { + match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + // Filters in semi/anti joins are either on the preserved side, or on join keys, + // as all output columns come from the preserved side. Join key filters can be + // safely pushed down into the other side. + JoinType::LeftSemi | JoinType::LeftAnti => (true, true), + JoinType::RightSemi | JoinType::RightAnti => (true, true), + JoinType::LeftMark => (true, false), + JoinType::RightMark => (false, true), + } } /// Accumulator for collecting min/max bounds from build-side data during hash join. @@ -1706,6 +1996,8 @@ async fn collect_left_input( _reservation: reservation, bounds, membership, + probe_side_non_empty: AtomicBool::new(false), + probe_side_has_null: AtomicBool::new(false), }; Ok(data) @@ -1829,6 +2121,26 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + /// Build a table with two columns supporting nullable values + fn build_table_two_cols( + a: (&str, &Vec>), + b: (&str, &Vec>), + ) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + ], + ) + .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + fn join( left: Arc, right: Arc, @@ -1845,6 +2157,7 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } @@ -1865,6 +2178,7 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } @@ -1963,6 +2277,7 @@ mod tests { None, partition_mode, null_equality, + false, )?; let columns = columns(&join.schema()); @@ -4846,6 +5161,7 @@ mod tests { None, PartitionMode::Partitioned, 
NullEquality::NullEqualsNothing, + false, )?; let stream = join.execute(1, task_ctx)?; @@ -5021,11 +5337,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, @@ -5036,6 +5347,7 @@ mod tests { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, )?; join.dynamic_filter = Some(HashJoinExecDynamicFilter { filter: dynamic_filter, @@ -5074,11 +5386,6 @@ mod tests { let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); let dynamic_filter_clone = Arc::clone(&dynamic_filter); - // Simulate a consumer by creating a transformed copy (what happens during filter pushdown) - let _consumer = Arc::clone(&dynamic_filter) - .with_new_children(vec![]) - .unwrap(); - // Create HashJoinExec with the dynamic filter let mut join = HashJoinExec::try_new( left, @@ -5089,6 +5396,7 @@ mod tests { None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + false, )?; join.dynamic_filter = Some(HashJoinExecDynamicFilter { filter: dynamic_filter, @@ -5320,4 +5628,249 @@ mod tests { Ok(()) } + + /// Test null-aware anti join when probe side (right) contains NULL + /// Expected: no rows should be output (NULL in subquery means all results are unknown) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_probe_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (rows to potentially output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("dummy", &vec![Some(10), Some(20), Some(30), Some(40)]), + ); + + // Build right table (subquery with NULL) + let right = 
build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3), None]), + ("dummy", &vec![Some(100), Some(200), Some(300), Some(400)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: empty result (probe side has NULL, so no rows should be output) + allow_duplicates! { + assert_snapshot!(batches_to_sort_string(&batches), @r" + ++ + ++ + "); + } + Ok(()) + } + + /// Test null-aware anti join when build side (left) contains NULL keys + /// Expected: rows with NULL keys should not be output + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_build_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table with NULL key (this row should not be output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(4), None]), + ("dummy", &vec![Some(10), Some(40), Some(0)]), + ); + + // Build right table (no NULL, so probe-side check passes) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) 
as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: only c1=4 (not c1=1 which matches, not c1=NULL) + allow_duplicates! { + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test null-aware anti join with no NULLs (should work like regular anti join) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_no_nulls(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (no NULLs) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(4), Some(5)]), + ("dummy", &vec![Some(10), Some(20), Some(40), Some(50)]), + ); + + // Build right table (no NULLs) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: c1=4 and c1=5 (they don't match anything in right) + allow_duplicates! 
{ + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + | 5 | 50 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test that null_aware validation rejects non-LeftAnti join types + #[tokio::test] + async fn test_null_aware_validation_wrong_join_type() { + let left = + build_table_two_cols(("c1", &vec![Some(1)]), ("dummy", &vec![Some(10)])); + let right = + build_table_two_cols(("c2", &vec![Some(1)]), ("dummy", &vec![Some(100)])); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c2", &right.schema()).unwrap()) as _, + )]; + + // Try to create null-aware Inner join (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true (invalid for Inner join) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware can only be true for LeftAnti joins") + ); + } + + /// Test that null_aware validation rejects multi-column joins + #[tokio::test] + async fn test_null_aware_validation_multi_column() { + let left = build_table(("a", &vec![1]), ("b", &vec![2]), ("c", &vec![3])); + let right = build_table(("x", &vec![1]), ("y", &vec![2]), ("z", &vec![3])); + + // Try multi-column join + let on = vec![ + ( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("x", &right.schema()).unwrap()) as _, + ), + ( + Arc::new(Column::new_with_schema("b", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("y", &right.schema()).unwrap()) as _, + ), + ]; + + // Try to create null-aware anti join with 2 columns (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + 
true, // null_aware = true (invalid for multi-column) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware anti join only supports single column join key") + ); + } + + #[test] + fn test_lr_is_preserved() { + assert_eq!(lr_is_preserved(JoinType::Inner), (true, true)); + assert_eq!(lr_is_preserved(JoinType::Left), (true, false)); + assert_eq!(lr_is_preserved(JoinType::Right), (false, true)); + assert_eq!(lr_is_preserved(JoinType::Full), (false, false)); + assert_eq!(lr_is_preserved(JoinType::LeftSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftMark), (true, false)); + assert_eq!(lr_is_preserved(JoinType::RightSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightMark), (false, true)); + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 8592e1d968535..b915802ea4015 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -17,7 +17,7 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator -pub use exec::HashJoinExec; +pub use exec::{HashJoinExec, HashJoinExecBuilder}; pub use partitioned_hash_eval::{HashExpr, HashTableLookupExpr, SeededRandomState}; mod exec; diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index a08ab2eedab3b..8af26c1b8a055 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -21,8 +21,10 @@ //! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. 
use std::sync::Arc; +use std::sync::atomic::Ordering; use std::task::Poll; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::joins::Map; use crate::joins::MapOffset; use crate::joins::PartitionMode; @@ -45,7 +47,6 @@ use crate::{ }; use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array}; -use arrow::compute::BatchCoalescer; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{ @@ -220,10 +221,11 @@ pub(super) struct HashJoinStream { build_waiter: Option>, /// Partitioning mode to use mode: PartitionMode, - /// Output buffer for coalescing small batches into larger ones. - /// Uses `BatchCoalescer` from arrow to efficiently combine batches. - /// When batches are already close to target size, they bypass coalescing. - output_buffer: Box, + /// Output buffer for coalescing small batches into larger ones with optional fetch limit. + /// Uses `LimitedBatchCoalescer` to efficiently combine batches and absorb limit with 'fetch' + output_buffer: LimitedBatchCoalescer, + /// Whether this is a null-aware anti join + null_aware: bool, } impl RecordBatchStream for HashJoinStream { @@ -371,14 +373,12 @@ impl HashJoinStream { right_side_ordered: bool, build_accumulator: Option>, mode: PartitionMode, + null_aware: bool, + fetch: Option, ) -> Self { - // Create output buffer with coalescing. - // Use biggest_coalesce_batch_size to bypass coalescing for batches - // that are already close to target size (within 50%). - let output_buffer = Box::new( - BatchCoalescer::new(Arc::clone(&schema), batch_size) - .with_biggest_coalesce_batch_size(Some(batch_size / 2)), - ); + // Create output buffer with coalescing and optional fetch limit. 
+ let output_buffer = + LimitedBatchCoalescer::new(Arc::clone(&schema), batch_size, fetch); Self { partition, @@ -402,6 +402,7 @@ impl HashJoinStream { build_waiter: None, mode, output_buffer, + null_aware, } } @@ -420,6 +421,11 @@ impl HashJoinStream { .record_poll(Poll::Ready(Some(Ok(batch)))); } + // Check if the coalescer has finished (limit reached and flushed) + if self.output_buffer.is_finished() { + return Poll::Ready(None); + } + return match self.state { HashJoinStreamState::WaitBuildSide => { handle_state!(ready!(self.collect_build_side(cx))) @@ -438,7 +444,7 @@ impl HashJoinStream { } HashJoinStreamState::Completed if !self.output_buffer.is_empty() => { // Flush any remaining buffered data - self.output_buffer.finish_buffered_batch()?; + self.output_buffer.finish()?; // Continue loop to emit the flushed batch continue; } @@ -484,6 +490,10 @@ impl HashJoinStream { )?; build_timer.done(); + // Note: For null-aware anti join, we need to check the probe side (right) for NULLs, + // not the build side (left). The probe-side NULL check happens during process_probe_batch. + // The probe_side_has_null flag will be set there if any probe batch contains NULL. + // Handle dynamic filter build-side information accumulation // // Dynamic filter coordination between partitions: @@ -595,6 +605,44 @@ impl HashJoinStream { let timer = self.join_metrics.join_time.timer(); + // Null-aware anti join semantics: + // For LeftAnti: output LEFT (build) rows where LEFT.key NOT IN RIGHT.key + // 1. If RIGHT (probe) contains NULL in any batch, no LEFT rows should be output + // 2. 
LEFT rows with NULL keys should not be output (handled in final stage) + if self.null_aware { + // Mark that we've seen a probe batch with actual rows (probe side is non-empty) + // Only set this if batch has rows - empty batches don't count + // Use shared atomic state so all partitions can see this global information + if state.batch.num_rows() > 0 { + build_side + .left_data + .probe_side_non_empty + .store(true, Ordering::Relaxed); + } + + // Check if probe side (RIGHT) contains NULL + // Since null_aware validation ensures single column join, we only check the first column + let probe_key_column = &state.values[0]; + if probe_key_column.null_count() > 0 { + // Found NULL in probe side - set shared flag to prevent any output + build_side + .left_data + .probe_side_has_null + .store(true, Ordering::Relaxed); + } + + // If probe side has NULL (detected in this or any other partition), return empty result + if build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::FetchProbeBatch; + return Ok(StatefulStreamResult::Continue); + } + } + // if the left side is empty, we can skip the (potentially expensive) join operation let is_empty = build_side.left_data.map().is_empty(); @@ -735,10 +783,17 @@ impl HashJoinStream { join_side, )?; - self.output_buffer.push_batch(batch)?; + let push_status = self.output_buffer.push_batch(batch)?; timer.done(); + // If limit reached, finish and move to Completed state + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + if next_offset.is_none() { self.state = HashJoinStreamState::FetchProbeBatch; } else { @@ -766,18 +821,66 @@ impl HashJoinStream { } let build_side = self.build_side.try_as_ready()?; + + // For null-aware anti join, if probe side had NULL, no rows should be output + // Check shared atomic state to get global knowledge 
across all partitions + if self.null_aware + && build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } if !build_side.left_data.report_probe_completed() { self.state = HashJoinStreamState::Completed; return Ok(StatefulStreamResult::Continue); } // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_shared_bitmap( + let (mut left_side, mut right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, true, ); + // For null-aware anti join, filter out LEFT rows with NULL in join keys + // BUT only if the probe side (RIGHT) was non-empty. If probe side is empty, + // NULL NOT IN (empty) = TRUE, so NULL rows should be returned. + // Use shared atomic state to get global knowledge across all partitions + if self.null_aware + && self.join_type == JoinType::LeftAnti + && build_side + .left_data + .probe_side_non_empty + .load(Ordering::Relaxed) + { + // Since null_aware validation ensures single column join, we only check the first column + let build_key_column = &build_side.left_data.values()[0]; + + // Filter out indices where the key is NULL + let filtered_indices: Vec = left_side + .iter() + .filter_map(|idx| { + let idx_usize = idx.unwrap() as usize; + if build_key_column.is_null(idx_usize) { + None // Skip rows with NULL keys + } else { + Some(idx.unwrap()) + } + }) + .collect(); + + left_side = UInt64Array::from(filtered_indices); + + // Update right_side to match the new length + let mut builder = arrow::array::UInt32Builder::with_capacity(left_side.len()); + builder.append_nulls(left_side.len()); + right_side = builder.finish(); + } + self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(left_side.len()); @@ -797,7 +900,12 @@ impl HashJoinStream { &self.column_indices, JoinSide::Left, 
)?; - self.output_buffer.push_batch(batch)?; + let push_status = self.output_buffer.push_batch(batch)?; + + // If limit reached, finish the coalescer + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + } } Ok(StatefulStreamResult::Continue) diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 848d0472fe885..2cdfa1e6ac020 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,8 +20,10 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; -pub use hash_join::{HashExpr, HashJoinExec, HashTableLookupExpr, SeededRandomState}; -pub use nested_loop_join::NestedLoopJoinExec; +pub use hash_join::{ + HashExpr, HashJoinExec, HashJoinExecBuilder, HashTableLookupExpr, SeededRandomState, +}; +pub use nested_loop_join::{NestedLoopJoinExec, NestedLoopJoinExecBuilder}; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet pub use piecewise_merge_join::PiecewiseMergeJoinExec; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 44637321a7e35..5b2cebb360439 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -71,6 +71,7 @@ use datafusion_physical_expr::equivalence::{ ProjectionMapping, join_equivalence_properties, }; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use futures::{Stream, StreamExt, TryStreamExt}; use log::debug; use parking_lot::Mutex; @@ -192,7 +193,7 @@ pub struct NestedLoopJoinExec { /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join - projection: Option>, + projection: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, @@ -200,34 +201,76 @@ pub struct 
NestedLoopJoinExec { cache: PlanProperties, } -impl NestedLoopJoinExec { - /// Try to create a new [`NestedLoopJoinExec`] - pub fn try_new( +/// Helps to build [`NestedLoopJoinExec`]. +pub struct NestedLoopJoinExecBuilder { + left: Arc, + right: Arc, + join_type: JoinType, + filter: Option, + projection: Option, +} + +impl NestedLoopJoinExecBuilder { + /// Make a new [`NestedLoopJoinExecBuilder`]. + pub fn new( left: Arc, right: Arc, - filter: Option, - join_type: &JoinType, - projection: Option>, - ) -> Result { + join_type: JoinType, + ) -> Self { + Self { + left, + right, + join_type, + filter: None, + projection: None, + } + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.projection = projection; + self + } + + /// Set optional filter. + pub fn with_filter(mut self, filter: Option) -> Self { + self.filter = filter; + self + } + + /// Build resulting execution plan. 
+ pub fn build(self) -> Result { + let Self { + left, + right, + join_type, + filter, + projection, + } = self; + let left_schema = left.schema(); let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &[])?; let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); + build_join_schema(&left_schema, &right_schema, &join_type); let join_schema = Arc::new(join_schema); - let cache = Self::compute_properties( + let cache = NestedLoopJoinExec::compute_properties( &left, &right, &join_schema, - *join_type, - projection.as_ref(), + join_type, + projection.as_deref(), )?; - Ok(NestedLoopJoinExec { left, right, filter, - join_type: *join_type, + join_type, join_schema, build_side_data: Default::default(), column_indices, @@ -236,6 +279,34 @@ impl NestedLoopJoinExec { cache, }) } +} + +impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder { + fn from(exec: &NestedLoopJoinExec) -> Self { + Self { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + join_type: exec.join_type, + filter: exec.filter.clone(), + projection: exec.projection.clone(), + } + } +} + +impl NestedLoopJoinExec { + /// Try to create a new [`NestedLoopJoinExec`] + pub fn try_new( + left: Arc, + right: Arc, + filter: Option, + join_type: &JoinType, + projection: Option>, + ) -> Result { + NestedLoopJoinExecBuilder::new(left, right, *join_type) + .with_projection(projection) + .with_filter(filter) + .build() + } /// left side pub fn left(&self) -> &Arc { @@ -257,8 +328,8 @@ impl NestedLoopJoinExec { &self.join_type } - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -267,7 +338,7 @@ impl NestedLoopJoinExec { right: &Arc, schema: &SchemaRef, join_type: JoinType, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -310,7 +381,7 @@ impl NestedLoopJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -334,22 +405,14 @@ impl NestedLoopJoinExec { } pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.filter.clone(), - &self.join_type, - projection, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + NestedLoopJoinExecBuilder::from(self) + .with_projection_ref(projection) + .build() } /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left @@ -371,7 +434,7 @@ impl NestedLoopJoinExec { swap_join_projection( left.schema().fields().len(), right.schema().fields().len(), - self.projection.as_ref(), + self.projection.as_deref(), self.join_type(), ), )?; @@ -476,13 +539,16 @@ impl ExecutionPlan for NestedLoopJoinExec { self: Arc, children: Vec>, ) 
-> Result> { - Ok(Arc::new(NestedLoopJoinExec::try_new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.filter.clone(), - &self.join_type, - self.projection.clone(), - )?)) + Ok(Arc::new( + NestedLoopJoinExecBuilder::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.join_type, + ) + .with_filter(self.filter.clone()) + .with_projection_ref(self.projection.clone()) + .build()?, + )) } fn execute( @@ -521,7 +587,7 @@ impl ExecutionPlan for NestedLoopJoinExec { let probe_side_data = self.right.execute(partition, context)?; // update column indices to reflect the projection - let column_indices_after_projection = match &self.projection { + let column_indices_after_projection = match self.projection.as_ref() { Some(projection) => projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -545,10 +611,6 @@ impl ExecutionPlan for NestedLoopJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // NestedLoopJoinExec is designed for joins without equijoin keys in the // ON clause (e.g., `t1 JOIN t2 ON (t1.v1 + t2.v1) % 2 = 0`). 
Any join @@ -682,10 +744,10 @@ async fn collect_left_input( let schema = stream.schema(); // Load all batches and count the rows - let (batches, metrics, mut reservation) = stream + let (batches, metrics, reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 508be2e3984f4..d7ece845e943c 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -620,7 +620,7 @@ async fn build_buffered_data( // Combine batches and record number of rows let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = buffered + let (batches, num_rows, metrics, reservation) = buffered .try_fold(initial, |mut acc, batch| async { let batch_size = get_record_batch_memory_size(&batch); acc.3.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index ae7a5fa764bcc..8778e4154e60e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -519,10 +519,6 @@ impl ExecutionPlan for SortMergeJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { // SortMergeJoinExec uses symmetric hash partitioning where both left and right // inputs are hash-partitioned on the join keys. 
This means partition `i` of the diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 1f6bc703a0300..4fdc5fc64dc67 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -52,7 +52,7 @@ use crate::projection::{ }; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, }; @@ -470,11 +470,6 @@ impl ExecutionPlan for SymmetricHashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - // TODO stats: it is not possible in general to know the output size of joins - Ok(Statistics::new_unknown(&self.schema())) - } - fn execute( &self, partition: usize, diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 27284bf546bc1..0455fb2a1eb6e 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -152,6 +152,7 @@ pub async fn partitioned_hash_join_with_filter( None, PartitionMode::Partitioned, null_equality, + false, // null_aware )?); let mut batches = vec![]; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index a9243fe04e28d..83fd418d73d72 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -977,6 +977,17 @@ pub(crate) fn apply_join_filter_to_indices( )) } +/// Creates a [RecordBatch] with zero columns but the given row count. +/// Used when a join has an empty projection (e.g. `SELECT count(1) ...`). 
+fn new_empty_schema_batch(schema: &Schema, row_count: usize) -> Result { + let options = RecordBatchOptions::new().with_row_count(Some(row_count)); + Ok(RecordBatch::try_new_with_options( + Arc::new(schema.clone()), + vec![], + &options, + )?) +} + /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. pub(crate) fn build_batch_from_indices( @@ -989,15 +1000,7 @@ pub(crate) fn build_batch_from_indices( build_side: JoinSide, ) -> Result { if schema.fields().is_empty() { - let options = RecordBatchOptions::new() - .with_match_field_names(true) - .with_row_count(Some(build_indices.len())); - - return Ok(RecordBatch::try_new_with_options( - Arc::new(schema.clone()), - vec![], - &options, - )?); + return new_empty_schema_batch(schema, build_indices.len()); } // build the columns of the new [RecordBatch]: @@ -1057,6 +1060,9 @@ pub(crate) fn build_batch_empty_build_side( // the remaining joins will return data for the right columns and null for the left ones JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => { let num_rows = probe_batch.num_rows(); + if schema.fields().is_empty() { + return new_empty_schema_batch(schema, num_rows); + } let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); @@ -1674,7 +1680,7 @@ fn swap_reverting_projection( pub fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, - projection: Option<&Vec>, + projection: Option<&[usize]>, join_type: &JoinType, ) -> Option> { match join_type { @@ -1685,7 +1691,7 @@ pub fn swap_join_projection( | JoinType::RightAnti | JoinType::RightSemi | JoinType::LeftMark - | JoinType::RightMark => projection.cloned(), + | JoinType::RightMark => projection.map(|p| p.to_vec()), _ => projection.map(|p| { p.iter() .map(|i| { @@ -2889,4 +2895,35 @@ mod tests { Ok(()) } + + #[test] + fn test_build_batch_empty_build_side_empty_schema() -> Result<()> { + // When the output schema has 
no fields (empty projection pushed into + // the join), build_batch_empty_build_side should return a RecordBatch + // with the correct row count but no columns. + let empty_schema = Schema::empty(); + + let build_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let probe_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])), + vec![Arc::new(arrow::array::Int32Array::from(vec![4, 5, 6, 7]))], + )?; + + let result = build_batch_empty_build_side( + &empty_schema, + &build_batch, + &probe_batch, + &[], // no column indices with empty projection + JoinType::Right, + )?; + + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index ec8e154caec91..6467d7a2e389d 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -24,8 +24,6 @@ // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -// https://github.com/apache/datafusion/issues/18881 -#![deny(clippy::allow_attributes)] //! Traits for physical query plan, supporting parallel execution for partitioned relations. //! 
@@ -65,9 +63,11 @@ mod visitor; pub mod aggregates; pub mod analyze; pub mod async_func; +pub mod buffer; pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; +pub mod column_rewriter; pub mod common; pub mod coop; pub mod display; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 05d6882821477..9ce63a1c586a6 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -35,6 +35,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::LexOrdering; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -51,6 +52,9 @@ pub struct GlobalLimitExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, cache: PlanProperties, + /// Does the limit have to preserve the order of its input, and if so what is it? + /// Some optimizations may reorder the input if no particular sort is required + required_ordering: Option, } impl GlobalLimitExec { @@ -63,6 +67,7 @@ impl GlobalLimitExec { fetch, metrics: ExecutionPlanMetricsSet::new(), cache, + required_ordering: None, } } @@ -91,6 +96,16 @@ impl GlobalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } } impl DisplayAs for GlobalLimitExec { @@ -194,10 +209,6 @@ impl ExecutionPlan for GlobalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input .partition_statistics(partition)? 
@@ -223,6 +234,9 @@ pub struct LocalLimitExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, cache: PlanProperties, + /// If the child plan is a sort node, after the sort node is removed during + /// physical optimization, we should add the required ordering to the limit node + required_ordering: Option, } impl LocalLimitExec { @@ -234,6 +248,7 @@ impl LocalLimitExec { fetch, metrics: ExecutionPlanMetricsSet::new(), cache, + required_ordering: None, } } @@ -257,6 +272,16 @@ impl LocalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } } impl DisplayAs for LocalLimitExec { @@ -340,10 +365,6 @@ impl ExecutionPlan for LocalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input .partition_statistics(partition)? 
diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 4a406ca648d57..a58abe20a23ee 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -27,7 +27,7 @@ use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, + RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::RecordBatch; @@ -352,10 +352,6 @@ impl ExecutionPlan for LazyMemoryExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) - } - fn reset_state(self: Arc) -> Result> { let generators = self .generators() diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 4d00b73cff39c..c91085965b07c 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -169,10 +169,6 @@ impl ExecutionPlan for PlaceholderRowExec { Ok(Box::pin(cooperative(ms))) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let batches = self .data() diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index e8608f17a1b20..55b4129223c24 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -20,16 +20,17 @@ //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. 
-use super::expressions::{Column, Literal}; +use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, SortOrderPushdownResult, Statistics, }; +use crate::column_rewriter::PhysicalColumnRewriter; use crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, FilterRemapper, PushedDownPredicate, }; use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef}; use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr}; @@ -45,8 +46,9 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{JoinSide, Result, internal_err}; +use datafusion_common::{DataFusionError, JoinSide, Result, internal_err}; use datafusion_execution::TaskContext; +use datafusion_expr::ExpressionPlacement; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::projection::Projector; use datafusion_physical_expr::utils::collect_columns; @@ -136,13 +138,19 @@ impl ProjectionExec { E: Into, { let input_schema = input.schema(); - // convert argument to Vec - let expr_vec = expr.into_iter().map(Into::into).collect::>(); - let projection = ProjectionExprs::new(expr_vec); + let expr_arc = expr.into_iter().map(Into::into).collect::>(); + let projection = ProjectionExprs::from_expressions(expr_arc); let projector = projection.make_projector(&input_schema)?; + Self::try_from_projector(projector, input) + } + fn try_from_projector( + projector: Projector, + input: Arc, + ) -> Result { // Construct a map from the input expressions to the output expression of the Projection 
- let projection_mapping = projection.projection_mapping(&input_schema)?; + let projection_mapping = + projector.projection().projection_mapping(&input.schema())?; let cache = Self::compute_properties( &input, &projection_mapping, @@ -192,6 +200,29 @@ impl ProjectionExec { input.boundedness(), )) } + + /// Collect reverse alias mapping from projection expressions. + /// The result hash map is a map from aliased Column in parent to original expr. + fn collect_reverse_alias( + &self, + ) -> Result>> { + let mut alias_map = datafusion_common::HashMap::new(); + for projection in self.projection_expr().iter() { + let (aliased_index, _output_field) = self + .projector + .output_schema() + .column_with_name(&projection.alias) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expr {} with alias {} not found in output schema", + projection.expr, projection.alias + )) + })?; + let aliased_col = Column::new(&projection.alias, aliased_index); + alias_map.insert(aliased_col, Arc::clone(&projection.expr)); + } + Ok(alias_map) + } } impl DisplayAs for ProjectionExec { @@ -261,10 +292,13 @@ impl ExecutionPlan for ProjectionExec { .as_ref() .iter() .all(|proj_expr| { - proj_expr.expr.as_any().is::() - || proj_expr.expr.as_any().is::() + !matches!( + proj_expr.expr.placement(), + ExpressionPlacement::KeepInPlace + ) }); - // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, + // If expressions are all either column_expr or Literal (or other cheap expressions), + // then all computations in this projection are reorder or rename, // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. 
vec![!all_simple_exprs] } @@ -277,8 +311,8 @@ impl ExecutionPlan for ProjectionExec { self: Arc, mut children: Vec>, ) -> Result> { - ProjectionExec::try_new( - self.projector.projection().clone(), + ProjectionExec::try_from_projector( + self.projector.clone(), children.swap_remove(0), ) .map(|p| Arc::new(p) as _) @@ -308,10 +342,6 @@ impl ExecutionPlan for ProjectionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stats = self.input.partition_statistics(partition)?; let output_schema = self.schema(); @@ -347,10 +377,28 @@ impl ExecutionPlan for ProjectionExec { parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - // TODO: In future, we can try to handle inverting aliases here. - // For the time being, we pass through untransformed filters, so filters on aliases are not handled. - // https://github.com/apache/datafusion/issues/17246 - FilterDescription::from_children(parent_filters, &self.children()) + // expand alias column to original expr in parent filters + let invert_alias_map = self.collect_reverse_alias()?; + let output_schema = self.schema(); + let remapper = FilterRemapper::new(output_schema); + let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); + + for filter in parent_filters { + // Check that column exists in child, then reassign column indices to match child schema + if let Some(reassigned) = remapper.try_remap(&filter)? 
{ + // rewrite filter expression using invert alias map + let mut rewriter = PhysicalColumnRewriter::new(&invert_alias_map); + let rewritten = reassigned.rewrite(&mut rewriter)?.data; + child_parent_filters.push(PushedDownPredicate::supported(rewritten)); + } else { + child_parent_filters.push(PushedDownPredicate::unsupported(filter)); + } + } + + Ok(FilterDescription::new().with_child(ChildFilterDescription { + parent_filters: child_parent_filters, + self_filters: vec![], + })) } fn handle_child_pushdown_result( @@ -427,6 +475,19 @@ impl ExecutionPlan for ProjectionExec { } } } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl ProjectionStream { @@ -485,6 +546,15 @@ impl RecordBatchStream for ProjectionStream { } } +/// Trait for execution plans that can embed a projection, avoiding a separate +/// [`ProjectionExec`] wrapper. +/// +/// # Empty projections +/// +/// `Some(vec![])` is a valid projection that produces zero output columns while +/// preserving the correct row count. Implementors must ensure that runtime batch +/// construction still returns batches with the right number of rows even when no +/// columns are selected (e.g. for `SELECT count(1) … JOIN …`). pub trait EmbeddedProjection: ExecutionPlan + Sized { fn with_projection(&self, projection: Option>) -> Result; } @@ -495,6 +565,15 @@ pub fn try_embed_projection( projection: &ProjectionExec, execution_plan: &Exec, ) -> Result>> { + // If the projection has no expressions at all (e.g., ProjectionExec: expr=[]), + // embed an empty projection into the execution plan so it outputs zero columns. + // This avoids allocating throwaway null arrays for build-side columns + // when no output columns are actually needed (e.g., count(1) over a right join). 
+ if projection.expr().is_empty() { + let new_execution_plan = Arc::new(execution_plan.with_projection(Some(vec![]))?); + return Ok(Some(new_execution_plan)); + } + // Collect all column indices from the given projection expressions. let projection_index = collect_column_indices(projection.expr()); @@ -945,11 +1024,15 @@ fn try_unifying_projections( .unwrap(); }); // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, unifies projections will be + // If an expression is not trivial (KeepInPlace) and it is referred more than 1, unifies projections will be // beneficial as caching mechanism for non-trivial computations. // See discussion in: https://github.com/apache/datafusion/issues/8296 if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr)) + *count > 1 + && !child.expr()[column.index()] + .expr + .placement() + .should_push_to_leaves() }) { return Ok(None); } @@ -1059,13 +1142,6 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Checks if the given expression is trivial. -/// An expression is considered trivial if it is either a `Column` or a `Literal`. 
-fn is_expr_trivial(expr: &Arc) -> bool { - expr.as_any().downcast_ref::().is_some() - || expr.as_any().downcast_ref::().is_some() -} - #[cfg(test)] mod tests { use super::*; @@ -1073,6 +1149,7 @@ mod tests { use crate::common::collect; + use crate::filter_pushdown::PushedDown; use crate::test; use crate::test::exec::StatisticsExec; @@ -1081,7 +1158,9 @@ mod tests { use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, col}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, Column, DynamicFilterPhysicalExpr, Literal, binary, col, lit, + }; #[test] fn test_collect_column_indices() -> Result<()> { @@ -1270,4 +1349,431 @@ mod tests { ); assert!(stats.total_byte_size.is_exact().unwrap_or(false)); } + + #[test] + fn test_filter_pushdown_with_alias() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&input_schema), + input_schema.clone(), + )); + + // project "a" as "b" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }], + input, + )?; + + // filter "b > 5" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" + // "a" is index 0 in input + let expected_filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + assert_eq!(description.self_filters(), vec![vec![]]); + let pushed_filters = &description.parent_filters()[0]; + assert_eq!( + format!("{}", 
pushed_filters[0].predicate), + format!("{}", expected_filter) + ); + // Verify the predicate was actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_multiple_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "y" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "y < 10" + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" and "b < 10" + let expected_filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let expected_filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // Note: The order of filters is preserved + assert_eq!( + format!("{}", 
pushed_filters[0].predicate), + format!("{}", expected_filter1) + ); + assert_eq!( + format!("{}", pushed_filters[1].predicate), + format!("{}", expected_filter2) + ); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_swapped_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "b", "b" as "a" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "a".to_string(), + }, + ], + input, + )?; + + // filter "b > 5" (output column 0, which is "a" in input) + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "a < 10" (output column 1, which is "b" in input) + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + + // "b" (output index 0) -> "a" (input index 0) + let expected_filter1 = "a@0 > 5"; + // "a" (output index 1) -> "b" (input index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), 
expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_mixed_columns() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "b" (pass through) + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "b < 10" (using output index 1 which corresponds to 'b') + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // "x" -> "a" (index 0) + let expected_filter1 = "a@0 > 5"; + // "b" -> "b" (index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down 
+ assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_complex_expression() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a + 1" as "z" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + alias: "z".to_string(), + }], + input, + )?; + + // filter "z > 10" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("z", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // expand to `a + 1 > 10` + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert_eq!(format!("{}", pushed_filters[0].predicate), "a@0 + 1 > 10"); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_unknown_column() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a".to_string(), + }], + input, + )?; + + // filter "unknown_col > 5" - using a column name that doesn't exist in projection output + // 
Column constructor: name, index. Index 1 doesn't exist. + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("unknown_col", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::No)); + // The column shouldn't be found in the alias map, so it remains unchanged with its index + assert_eq!( + format!("{}", pushed_filters[0].predicate), + "unknown_col@1 > 5" + ); + + Ok(()) + } + + /// Basic test for `DynamicFilterPhysicalExpr` can correctly update its child expression + /// i.e. starting with lit(true) and after update it becomes `a > 5` + /// with projection [b - 1 as a], the pushed down filter should be `b - 1 > 5` + #[test] + fn test_basic_dyn_filter_projection_pushdown_update_child() -> Result<()> { + let input_schema = + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)])); + + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.as_ref().clone(), + )); + + // project "b" - 1 as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: binary( + Arc::new(Column::new("b", 0)), + Operator::Minus, + lit(1), + &input_schema, + ) + .unwrap(), + alias: "a".to_string(), + }], + input, + )?; + + // simulate projection's parent create a dynamic filter on "a" + let projected_schema = projection.schema(); + let col_a = col("a", &projected_schema)?; + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true), + )); + // Initial state should be lit(true) + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "true"); + + let dyn_phy_expr: Arc = 
Arc::clone(&dynamic_filter) as _; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![dyn_phy_expr], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0][0]; + + // Check currently pushed_filters is lit(true) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ empty ]" + ); + + // Update to a > 5 (after projection, b is now called a) + let new_expr = + Arc::new(BinaryExpr::new(Arc::clone(&col_a), Operator::Gt, lit(5i32))); + dynamic_filter.update(new_expr)?; + + // Now it should be a > 5 + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "a@0 > 5"); + + // Check currently pushed_filters is b - 1 > 5 (because b - 1 is projected as a) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ b@0 - 1 > 5 ]" + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 683dbb4e49765..f2cba13717acc 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,13 +24,13 @@ use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; @@ -208,10 +208,6 @@ impl ExecutionPlan for RecursiveQueryExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - - fn 
statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } } impl DisplayAs for RecursiveQueryExec { @@ -387,20 +383,6 @@ fn assign_work_table( .data() } -/// Some plans will change their internal states after execution, making them unable to be executed again. -/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. -/// -/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. -/// However, if the data of the left table is derived from the work table, it will become outdated -/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. -fn reset_plan_states(plan: Arc) -> Result> { - plan.transform_up(|plan| { - let new_plan = Arc::clone(&plan).reset_state()?; - Ok(Transformed::yes(new_plan)) - }) - .data() -} - impl Stream for RecursiveQueryStream { type Item = Result; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index d50404c8fc1e8..2b0c0ea31689b 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -48,7 +48,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::utils::transpose; use datafusion_common::{ - ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err, + ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, + internal_datafusion_err, internal_err, }; use datafusion_common::{Result, not_impl_err}; use datafusion_common_runtime::SpawnedTask; @@ -421,6 +422,7 @@ enum BatchPartitionerState { exprs: Vec>, num_partitions: usize, hash_buffer: Vec, + indices: Vec>, }, RoundRobin { num_partitions: usize, @@ -453,6 +455,7 @@ impl BatchPartitioner { exprs, num_partitions, hash_buffer: vec![], + indices: vec![vec![]; num_partitions], }, timer, } @@ -562,6 +565,7 @@ impl 
BatchPartitioner { exprs, num_partitions: partitions, hash_buffer, + indices, } => { // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); @@ -578,9 +582,7 @@ impl BatchPartitioner { hash_buffer, )?; - let mut indices: Vec<_> = (0..*partitions) - .map(|_| Vec::with_capacity(batch.num_rows())) - .collect(); + indices.iter_mut().for_each(|v| v.clear()); for (index, hash) in hash_buffer.iter().enumerate() { indices[(*hash % *partitions as u64) as usize].push(index as u32); @@ -591,22 +593,23 @@ impl BatchPartitioner { // Borrowing partitioner timer to prevent moving `self` to closure let partitioner_timer = &self.timer; - let it = indices - .into_iter() - .enumerate() - .filter_map(|(partition, indices)| { - let indices: PrimitiveArray = indices.into(); - (!indices.is_empty()).then_some((partition, indices)) - }) - .map(move |(partition, indices)| { + + let mut partitioned_batches = vec![]; + for (partition, p_indices) in indices.iter_mut().enumerate() { + if !p_indices.is_empty() { + let taken_indices = std::mem::take(p_indices); + let indices_array: PrimitiveArray = + taken_indices.into(); + // Tracking time required for repartitioned batches construction let _timer = partitioner_timer.timer(); // Produce batches based on indices - let columns = take_arrays(batch.columns(), &indices, None)?; + let columns = + take_arrays(batch.columns(), &indices_array, None)?; let mut options = RecordBatchOptions::new(); - options = options.with_row_count(Some(indices.len())); + options = options.with_row_count(Some(indices_array.len())); let batch = RecordBatch::try_new_with_options( batch.schema(), columns, @@ -614,10 +617,22 @@ impl BatchPartitioner { ) .unwrap(); - Ok((partition, batch)) - }); + partitioned_batches.push(Ok((partition, batch))); + + // Return the taken vec + let (_, buffer, _) = indices_array.into_parts(); + let mut vec = + buffer.into_inner().into_vec::().map_err(|e| { + internal_datafusion_err!( + "Could 
not convert buffer to vec: {e:?}" + ) + })?; + vec.clear(); + *p_indices = vec; + } + } - Box::new(it) + Box::new(partitioned_batches.into_iter()) } }; @@ -731,6 +746,10 @@ impl BatchPartitioner { /// system Paper](https://dl.acm.org/doi/pdf/10.1145/93605.98720) /// which uses the term "Exchange" for the concept of repartitioning /// data across threads. +/// +/// For more background, please also see the [Optimizing Repartitions in DataFusion] blog. +/// +/// [Optimizing Repartitions in DataFusion]: https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions #[derive(Debug, Clone)] pub struct RepartitionExec { /// Input execution plan @@ -1051,10 +1070,6 @@ impl ExecutionPlan for RepartitionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition) = partition { let partition_count = self.partitioning().partition_count(); diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 73ba889c9e40b..08bc73c92d4b3 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -329,10 +329,6 @@ impl ExecutionPlan for PartialSortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { self.input.partition_statistics(partition) } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 3e8fdf1f3ed7e..55e1f460e1901 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -709,7 +709,7 @@ impl ExternalSorter { &self, batch: RecordBatch, metrics: &BaselineMetrics, - mut reservation: MemoryReservation, + reservation: MemoryReservation, ) -> Result { assert_eq!( 
get_reserved_bytes_for_record_batch(&batch)?, @@ -736,7 +736,7 @@ impl ExternalSorter { .then({ move |batches| async move { match batches { - Ok((schema, sorted_batches, mut reservation)) => { + Ok((schema, sorted_batches, reservation)) => { // Calculate the total size of sorted batches let total_sorted_size: usize = sorted_batches .iter() @@ -819,7 +819,8 @@ impl ExternalSorter { match e { DataFusionError::ResourcesExhausted(_) => e.context( "Not enough memory to continue external sort. \ - Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes" + Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', \ + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'." ), // This is not an OOM error, so just return it as is. _ => e, @@ -1352,10 +1353,6 @@ impl ExecutionPlan for SortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if !self.preserve_partitioning() { return self @@ -1736,6 +1733,21 @@ mod tests { "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" ); + // Verify external sorter error message when resource is exhausted + let config_vector = vec![ + "datafusion.runtime.memory_limit", + "datafusion.execution.sort_spill_reservation_bytes", + ]; + let error_message = err.message().to_string(); + for config in config_vector.into_iter() { + assert!( + error_message.as_str().contains(config), + "Config: '{}' should be contained in error message: {}.", + config, + error_message.as_str() + ); + } + Ok(()) } @@ -1756,7 +1768,7 @@ mod tests { // The input has 200 partitions, each partition has a batch containing 100 rows. // Each row has a single Utf8 column, the Utf8 string values are roughly 42 bytes. - // The total size of the input is roughly 8.4 KB. + // The total size of the input is roughly 820 KB. 
let input = test::scan_partitioned_utf8(200); let schema = input.schema(); diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 0ddea90a98bf3..6c1bb4883d1ad 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -245,6 +245,19 @@ impl ExecutionPlan for SortPreservingMergeExec { })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } @@ -359,10 +372,6 @@ impl ExecutionPlan for SortPreservingMergeExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - fn partition_statistics(&self, _partition: Option) -> Result { self.input.partition_statistics(None) } diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index a510f44e4f4df..779511a865b6a 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -180,7 +180,7 @@ impl RowCursorStream { self.rows.save(stream_idx, &rows); // track the memory in the newly created Rows. 
- let mut rows_reservation = self.reservation.new_empty(); + let rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; Ok(RowValues::new(rows, rows_reservation)) } @@ -246,7 +246,7 @@ impl FieldCursorStream { let array = value.into_array(batch.num_rows())?; let size_in_mem = array.get_buffer_memory_size(); let array = array.as_any().downcast_ref::().expect("field values"); - let mut array_reservation = self.reservation.new_empty(); + let array_reservation = self.reservation.new_empty(); array_reservation.try_grow(size_in_mem)?; Ok(ArrayValues::new( self.sort.options, diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index d2acf4993b857..2666ab8822ed9 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -63,7 +63,7 @@ impl InProgressSpillFile { } if self.writer.is_none() { let schema = batch.schema(); - if let Some(ref in_progress_file) = self.in_progress_file { + if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), schema.as_ref(), @@ -72,18 +72,38 @@ impl InProgressSpillFile { // Update metrics self.spill_writer.metrics.spill_file_count.add(1); + + // Update initial size (schema/header) + in_progress_file.update_disk_usage()?; + let initial_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add(initial_size as usize); } } if let Some(writer) = &mut self.writer { let (spilled_rows, _) = writer.write(batch)?; if let Some(in_progress_file) = &mut self.in_progress_file { + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; + let post_size = in_progress_file.current_disk_usage(); + + self.spill_writer.metrics.spilled_rows.add(spilled_rows); + self.spill_writer + .metrics + .spilled_bytes + 
.add((post_size - pre_size) as usize); } else { unreachable!() // Already checked inside current function } + } + Ok(()) + } - // Update metrics - self.spill_writer.metrics.spilled_rows.add(spilled_rows); + pub fn flush(&mut self) -> Result<()> { + if let Some(writer) = &mut self.writer { + writer.flush()?; } Ok(()) } @@ -106,9 +126,13 @@ impl InProgressSpillFile { // Since spill files are append-only, add the file size to spilled_bytes if let Some(in_progress_file) = &mut self.in_progress_file { // Since writer.finish() writes continuation marker and message length at the end + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; - let size = in_progress_file.current_disk_usage(); - self.spill_writer.metrics.spilled_bytes.add(size as usize); + let post_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add((post_size - pre_size) as usize); } Ok(self.in_progress_file.take()) diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 78dea99ac820c..4c93c03b342eb 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -49,7 +49,7 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::RecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; use futures::{FutureExt as _, Stream}; -use log::warn; +use log::debug; /// Stream that reads spill files from disk where each batch is read in a spawned blocking task /// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`] @@ -154,7 +154,7 @@ impl SpillReaderStream { > max_record_batch_memory + SPILL_BATCH_MEMORY_MARGIN { - warn!( + debug!( "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\ by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\ This likely indicates a 
bug in memory accounting during spilling.\n\ @@ -310,6 +310,11 @@ impl IPCStreamWriter { Ok((delta_num_rows, delta_num_bytes)) } + pub fn flush(&mut self) -> Result<()> { + self.writer.flush()?; + Ok(()) + } + /// Finish the writer pub fn finish(&mut self) -> Result<()> { self.writer.finish().map_err(Into::into) @@ -685,13 +690,13 @@ mod tests { Arc::new(StringArray::from(vec!["d", "e", "f"])), ], )?; - // After appending each batch, spilled_rows should increase, while spill_file_count and - // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called) + // After appending each batch, spilled_rows and spilled_bytes should increase incrementally, + // while spill_file_count remains 1 (since we're writing to the same file) in_progress_file.append_batch(&batch1)?; - verify_metrics(&in_progress_file, 1, 0, 3)?; + verify_metrics(&in_progress_file, 1, 440, 3)?; in_progress_file.append_batch(&batch2)?; - verify_metrics(&in_progress_file, 1, 0, 6)?; + verify_metrics(&in_progress_file, 1, 704, 6)?; let completed_file = in_progress_file.finish()?; assert!(completed_file.is_some()); @@ -799,4 +804,70 @@ mod tests { assert_eq!(alignment, 8); Ok(()) } + #[tokio::test] + async fn test_real_time_spill_metrics() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + + let spill_manager = Arc::new(SpillManager::new( + Arc::clone(&env), + metrics.clone(), + Arc::clone(&schema), + )); + let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Before any batch, metrics should be 0 + assert_eq!(metrics.spilled_bytes.value(), 0); + 
assert_eq!(metrics.spill_file_count.value(), 0); + + // Append first batch + in_progress_file.append_batch(&batch1)?; + + // Metrics should be updated immediately (at least schema and first batch) + let bytes_after_batch1 = metrics.spilled_bytes.value(); + assert_eq!(bytes_after_batch1, 440); + assert_eq!(metrics.spill_file_count.value(), 1); + + // Check global progress + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch1 as u64); + assert_eq!(progress.active_files_count, 1); + + // Append another batch + in_progress_file.append_batch(&batch1)?; + let bytes_after_batch2 = metrics.spilled_bytes.value(); + assert!(bytes_after_batch2 > bytes_after_batch1); + + // Check global progress again + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch2 as u64); + + // Finish the file + let spilled_file = in_progress_file.finish()?; + let final_bytes = metrics.spilled_bytes.value(); + assert!(final_bytes > bytes_after_batch2); + + // Even after finish, file is still "active" until dropped + let progress = env.spilling_progress(); + assert!(progress.current_bytes > 0); + assert_eq!(progress.active_files_count, 1); + + drop(spilled_file); + assert_eq!(env.spilling_progress().active_files_count, 0); + assert_eq!(env.spilling_progress().current_bytes, 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 89b0276206774..6d931112ad888 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -188,6 +188,19 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + /// Same as `read_spill_as_stream`, but without buffering. 
+ pub fn read_spill_as_stream_unbuffered( + &self, + spill_file_path: RefCountedTempFile, + max_record_batch_memory: Option, + ) -> Result { + Ok(Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path, + max_record_batch_memory, + )))) + } } pub(crate) trait GetSlicedSize { diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index e3b547b5731f3..1b9d82eaf4506 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -194,6 +194,8 @@ impl SpillPoolWriter { // Append the batch if let Some(ref mut writer) = file_shared.writer { writer.append_batch(batch)?; + // make sure we flush the writer for readers + writer.flush()?; file_shared.batches_written += 1; file_shared.estimated_size += batch_size; } @@ -535,7 +537,11 @@ impl Stream for SpillFile { // Step 2: Lazy-create reader stream if needed if self.reader.is_none() && should_read { if let Some(file) = file { - match self.spill_manager.read_spill_as_stream(file, None) { + // we want this unbuffered because files are actively being written to + match self + .spill_manager + .read_spill_as_stream_unbuffered(file, None) + { Ok(stream) => { self.reader = Some(SpillFileReader { stream, @@ -879,8 +885,8 @@ mod tests { ); assert_eq!( metrics.spilled_bytes.value(), - 0, - "Spilled bytes should be 0 before file finalization" + 320, + "Spilled bytes should reflect data written (header + 1 batch)" ); assert_eq!( metrics.spilled_rows.value(), @@ -1300,11 +1306,11 @@ mod tests { writer.push_batch(&batch)?; } - // Check metrics before drop - spilled_bytes should be 0 since file isn't finalized yet + // Check metrics before drop - spilled_bytes already reflects written data let spilled_bytes_before = metrics.spilled_bytes.value(); assert_eq!( - spilled_bytes_before, 0, - "Spilled bytes should be 0 before writer is dropped" + spilled_bytes_before, 1088, + "Spilled bytes should 
reflect data written (header + 5 batches)" ); // Explicitly drop the writer - this should finalize the current file diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 80c2233d05db6..4b7e707fccedd 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -1005,7 +1005,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -1071,7 +1071,7 @@ mod test { .build_arc() .unwrap(); - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index c94b5a4131397..a967d035bd387 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -146,7 +146,7 @@ impl ExecutionPlan for TestMemoryExec { self: Arc, _: Vec>, ) -> Result> { - unimplemented!() + Ok(self) } fn repartitioned( @@ -169,10 +169,6 @@ impl ExecutionPlan for TestMemoryExec { unimplemented!() } - fn statistics(&self) -> Result { - self.statistics_inner() - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { Ok(Statistics::new_unknown(&self.schema)) diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 4507cccba05a9..ebed84477a568 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -254,10 +254,6 @@ impl ExecutionPlan for MockExec { } // Panics if one of the batches is an error - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - 
fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema)); @@ -410,10 +406,6 @@ impl ExecutionPlan for BarrierExec { Ok(builder.build()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema)); @@ -600,10 +592,6 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - fn partition_statistics(&self, partition: Option) -> Result { Ok(if partition.is_some() { Statistics::new_unknown(&self.schema) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index ebac497f4fbc3..4b93e6a188d57 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -131,6 +131,9 @@ pub struct TopK { pub(crate) finished: bool, } +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters #[derive(Debug, Clone)] pub struct TopKDynamicFilters { /// The current *global* threshold for the dynamic filter. 
diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index d27c81b968490..8174160dc9332 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -36,7 +36,11 @@ use crate::execution_plan::{ InvariantLevel, boundedness_from_children, check_default_invariants, emission_type_from_children, }; -use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; +use crate::filter::FilterExec; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, +}; use crate::metrics::BaselineMetrics; use crate::projection::{ProjectionExec, make_with_child}; use crate::stream::ObservedStream; @@ -49,7 +53,9 @@ use datafusion_common::{ Result, assert_or_internal_err, exec_err, internal_datafusion_err, }; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, calculate_union}; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalExpr, calculate_union, conjunction, +}; use futures::Stream; use itertools::Itertools; @@ -304,10 +310,6 @@ impl ExecutionPlan for UnionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { if let Some(partition_idx) = partition { // For a specific partition, find which input it belongs to @@ -370,6 +372,83 @@ impl ExecutionPlan for UnionExec { ) -> Result { FilterDescription::from_children(parent_filters, &self.children()) } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Pre phase: handle heterogeneous pushdown by wrapping individual + // children with FilterExec and reporting all filters as handled. 
+ // Post phase: use default behavior to let the filter creator decide how to handle + // filters that weren't fully pushed down. + if !matches!(phase, FilterPushdownPhase::Pre) { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // UnionExec needs specialized filter pushdown handling when children have + // heterogeneous pushdown support. Without this, when some children support + // pushdown and others don't, the default behavior would leave FilterExec + // above UnionExec, re-applying filters to outputs of all children—including + // those that already applied the filters via pushdown. This specialized + // implementation adds FilterExec only to children that don't support + // pushdown, avoiding redundant filtering and improving performance. + // + // Example: Given Child1 (no pushdown support) and Child2 (has pushdown support) + // Default behavior: This implementation: + // FilterExec UnionExec + // UnionExec FilterExec + // Child1 Child1 + // Child2(filter) Child2(filter) + + // Collect unsupported filters for each child + let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()]; + for parent_filter_result in child_pushdown_result.parent_filters.iter() { + for (child_idx, &child_result) in + parent_filter_result.child_results.iter().enumerate() + { + if matches!(child_result, PushedDown::No) { + unsupported_filters_per_child[child_idx] + .push(Arc::clone(&parent_filter_result.filter)); + } + } + } + + // Wrap children that have unsupported filters with FilterExec + let mut new_children = self.inputs.clone(); + for (child_idx, unsupported_filters) in + unsupported_filters_per_child.iter().enumerate() + { + if !unsupported_filters.is_empty() { + let combined_filter = conjunction(unsupported_filters.clone()); + new_children[child_idx] = Arc::new(FilterExec::try_new( + combined_filter, + Arc::clone(&self.inputs[child_idx]), + )?); + } + } + + // Check if any children were modified + let children_modified = 
new_children + .iter() + .zip(self.inputs.iter()) + .any(|(new, old)| !Arc::ptr_eq(new, old)); + + let all_filters_pushed = + vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()]; + let propagation = if children_modified { + let updated_node = UnionExec::try_new(new_children)?; + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + .with_updated_node(updated_node) + } else { + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + }; + + // Report all parent filters as supported since we've ensured they're applied + // on all children (either pushed down or via FilterExec) + Ok(propagation) + } } /// Combines multiple input streams by interleaving them. @@ -545,10 +624,6 @@ impl ExecutionPlan for InterleaveExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let stats = self .inputs @@ -593,8 +668,20 @@ fn union_schema(inputs: &[Arc]) -> Result { } let first_schema = inputs[0].schema(); + let first_field_count = first_schema.fields().len(); + + // validate that all inputs have the same number of fields + for (idx, input) in inputs.iter().enumerate().skip(1) { + let field_count = input.schema().fields().len(); + if field_count != first_field_count { + return exec_err!( + "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \ + Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields" + ); + } + } - let fields = (0..first_schema.fields().len()) + let fields = (0..first_field_count) .map(|i| { // We take the name from the left side of the union to match how names are coerced during logical planning, // which also uses the left side names. 
@@ -763,6 +850,18 @@ mod tests { Ok(schema) } + fn create_test_schema2() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -1052,4 +1151,23 @@ mod tests { Ok(()) } + + #[test] + fn test_union_schema_mismatch() { + // Test that UnionExec properly rejects inputs with different field counts + let schema = create_test_schema().unwrap(); + let schema2 = create_test_schema2().unwrap(); + let memory_exec1 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap()); + let memory_exec2 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap()); + + let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]); + assert!(result.is_err()); + assert!( + result.unwrap_err().to_string().contains( + "UnionExec/InterleaveExec requires all inputs to have the same number of fields" + ) + ); + } } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 987a400ec369e..20d54303a94b4 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -368,10 +368,6 @@ impl ExecutionPlan for BoundedWindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stat = self.input.partition_statistics(partition)?; self.statistics_helper(input_stat) diff --git 
a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index aa99f4f49885a..0c73cf23523d5 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -272,10 +272,6 @@ impl ExecutionPlan for WindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - fn partition_statistics(&self, partition: Option) -> Result { let input_stat = self.input.partition_statistics(partition)?; let win_cols = self.window_expr.len(); diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index f1b9e3e88d123..08390f87a2033 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -231,10 +231,6 @@ impl ExecutionPlan for WorkTableExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - fn partition_statistics(&self, _partition: Option) -> Result { Ok(Statistics::new_unknown(&self.schema())) } @@ -283,7 +279,7 @@ mod tests { assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; - let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); + let reservation = MemoryConsumer::new("test_work_table").register(&pool); // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 2d2557811d0df..f0e60819d42a8 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -37,5 +37,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.8.0" -prost-build = "=0.14.1" +pbjson-build = "=0.9.0" +prost-build = "=0.14.3" diff --git 
a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 08bb25bd715b9..62c6bbe85612a 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -183,6 +183,11 @@ message Map { bool keys_sorted = 2; } +message RunEndEncoded { + Field run_ends_field = 1; + Field values_field = 2; +} + enum UnionMode{ sparse = 0; dense = 1; @@ -236,6 +241,12 @@ message ScalarDictionaryValue { ScalarValue value = 2; } +message ScalarRunEndEncodedValue { + Field run_ends_field = 1; + Field values_field = 2; + ScalarValue value = 3; +} + message IntervalDayTimeValue { int32 days = 1; int32 milliseconds = 2; @@ -321,6 +332,8 @@ message ScalarValue{ IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; UnionValue union_value = 42; + + ScalarRunEndEncodedValue run_end_encoded_value = 45; } } @@ -389,6 +402,7 @@ message ArrowType{ Union UNION = 29; Dictionary DICTIONARY = 30; Map MAP = 33; + RunEndEncoded RUN_END_ENCODED = 42; } } @@ -469,6 +483,7 @@ message JsonOptions { CompressionTypeVariant compression = 1; // Compression type optional uint64 schema_infer_max_rec = 2; // Optional max records for schema inference optional uint32 compression_level = 3; // Optional compression level + optional bool newline_delimited = 4; // Whether to read as newline-delimited JSON (default true). When false, expects JSON array format [{},...] 
} message TableParquetOptions { diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index e8e71c3884586..ca8a269958d73 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -28,7 +28,12 @@ use arrow::datatypes::{ DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, i256, }; -use arrow::ipc::{reader::read_record_batch, root_as_message}; +use arrow::ipc::{ + convert::fb_to_schema, + reader::{read_dictionary, read_record_batch}, + root_as_message, + writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}, +}; use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, @@ -304,13 +309,16 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { }; let union_fields = parse_proto_fields_to_fields(&union.union_types)?; - // Default to index based type ids if not provided - let type_ids: Vec<_> = match union.type_ids.is_empty() { - true => (0..union_fields.len() as i8).collect(), - false => union.type_ids.iter().map(|i| *i as i8).collect(), + // Default to index based type ids if not explicitly provided + let union_fields = if union.type_ids.is_empty() { + UnionFields::from_fields(union_fields) + } else { + let type_ids = union.type_ids.iter().map(|i| *i as i8); + UnionFields::try_new(type_ids, union_fields).map_err(|e| { + DataFusionError::from(e).context("Deserializing Union DataType") + })? 
}; - - DataType::Union(UnionFields::new(type_ids, union_fields), union_mode) + DataType::Union(union_fields, union_mode) } arrow_type::ArrowTypeEnum::Dictionary(dict) => { let key_datatype = dict.as_ref().key.as_deref().required("key")?; @@ -323,6 +331,19 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { let keys_sorted = map.keys_sorted; DataType::Map(Arc::new(field), keys_sorted) } + arrow_type::ArrowTypeEnum::RunEndEncoded(run_end_encoded) => { + let run_ends_field: Field = run_end_encoded + .as_ref() + .run_ends_field + .as_deref() + .required("run_ends_field")?; + let value_field: Field = run_end_encoded + .as_ref() + .values_field + .as_deref() + .required("values_field")?; + DataType::RunEndEncoded(run_ends_field.into(), value_field.into()) + } }) } } @@ -381,7 +402,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - // ScalarValue::List is serialized using arrow IPC format + // Nested ScalarValue types are serialized using arrow IPC format Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) @@ -398,55 +419,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { schema_ref.try_into()? } else { return Err(Error::General( - "Invalid schema while deserializing ScalarValue::List" + "Invalid schema while deserializing nested ScalarValue" .to_string(), )); }; + // IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf + // `Schema` doesn't preserve those IDs. Reconstruct them deterministically by + // round-tripping the schema through IPC. 
+ let schema: Schema = { + let ipc_gen = IpcDataGenerator {}; + let write_options = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker( + &schema, + &mut dict_tracker, + &write_options, + ); + let message = + root_as_message(encoded_schema.ipc_message.as_slice()).map_err( + |e| { + Error::General(format!( + "Error IPC schema message while deserializing nested ScalarValue: {e}" + )) + }, + )?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing nested ScalarValue schema" + .to_string(), + ) + })?; + fb_to_schema(ipc_schema) + }; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List: {e}" + "Error IPC message while deserializing nested ScalarValue: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let ipc_batch = message.header_as_record_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List" + "Unexpected message type deserializing nested ScalarValue" .to_string(), ) })?; - let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let mut dict_by_id: HashMap = HashMap::new(); + for protobuf::scalar_nested_value::Dictionary { + ipc_message, + arrow_data, + } in dictionaries + { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + "Error IPC message while deserializing nested ScalarValue dictionary message: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List dictionary message" + "Unexpected message 
type deserializing nested ScalarValue dictionary message" .to_string(), ) })?; - - let id = dict_batch.id(); - - let record_batch = read_record_batch( + read_dictionary( &buffer, - dict_batch.data().unwrap(), - Arc::new(schema.clone()), - &Default::default(), - None, + dict_batch, + &schema, + &mut dict_by_id, &message.version(), - )?; - - let values: ArrayRef = Arc::clone(record_batch.column(0)); - - Ok((id, values)) - }).collect::>>()?; + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?; + } let record_batch = read_record_batch( &buffer, @@ -457,7 +506,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { &message.version(), ) .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + .map_err(|e| e.context("Decoding nested ScalarValue value"))?; let arr = record_batch.column(0); match value { Value::ListValue(_) => { @@ -575,6 +624,32 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Self::Dictionary(Box::new(index_type), Box::new(value)) } + Value::RunEndEncodedValue(v) => { + let run_ends_field: Field = v + .run_ends_field + .as_ref() + .ok_or_else(|| Error::required("run_ends_field"))? + .try_into()?; + + let values_field: Field = v + .values_field + .as_ref() + .ok_or_else(|| Error::required("values_field"))? + .try_into()?; + + let value: Self = v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? 
+ .as_ref() + .try_into()?; + + Self::RunEndEncoded( + run_ends_field.into(), + values_field.into(), + Box::new(value), + ) + } Value::BinaryValue(v) => Self::Binary(Some(v.clone())), Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), @@ -602,7 +677,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .collect::>>(); let fields = fields.ok_or_else(|| Error::required("UnionField"))?; let fields = parse_proto_fields_to_fields(&fields)?; - let fields = UnionFields::new(ids, fields); + let union_fields = UnionFields::try_new(ids, fields).map_err(|e| { + DataFusionError::from(e).context("Deserializing Union ScalarValue") + })?; let v_id = val.value_id as i8; let val = match &val.value { None => None, @@ -614,7 +691,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Some((v_id, Box::new(val))) } }; - Self::Union(val, fields, mode) + Self::Union(val, union_fields, mode) } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) @@ -1100,6 +1177,7 @@ impl TryFrom<&protobuf::JsonOptions> for JsonOptions { compression: compression.into(), compression_level: proto_opts.compression_level, schema_infer_max_rec: proto_opts.schema_infer_max_rec.map(|h| h as usize), + newline_delimited: proto_opts.newline_delimited.unwrap_or(true), }) } } diff --git a/datafusion/proto-common/src/generated/mod.rs b/datafusion/proto-common/src/generated/mod.rs index 08cd75b622db3..9c2ca9385aa5e 100644 --- a/datafusion/proto-common/src/generated/mod.rs +++ b/datafusion/proto-common/src/generated/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+// This code is generated so we don't want to fix any lint violations manually #[allow(clippy::allow_attributes)] #[allow(clippy::all)] #[rustfmt::skip] diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index d38cf86825d46..b00e7546bba20 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -29,7 +29,7 @@ impl<'de> serde::Deserialize<'de> for ArrowFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -100,7 +100,7 @@ impl<'de> serde::Deserialize<'de> for ArrowOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -276,6 +276,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Map(v) => { struct_ser.serialize_field("MAP", v)?; } + arrow_type::ArrowTypeEnum::RunEndEncoded(v) => { + struct_ser.serialize_field("RUNENDENCODED", v)?; + } } } struct_ser.end() @@ -333,6 +336,8 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION", "DICTIONARY", "MAP", + "RUN_END_ENCODED", + "RUNENDENCODED", ]; #[allow(clippy::enum_variant_names)] @@ -375,6 +380,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Union, Dictionary, Map, + RunEndEncoded, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -383,7 +389,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: 
&mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -434,6 +440,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION" => Ok(GeneratedField::Union), "DICTIONARY" => Ok(GeneratedField::Dictionary), "MAP" => Ok(GeneratedField::Map), + "RUNENDENCODED" | "RUN_END_ENCODED" => Ok(GeneratedField::RunEndEncoded), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -715,6 +722,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("MAP")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) +; + } + GeneratedField::RunEndEncoded => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("RUNENDENCODED")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::RunEndEncoded) ; } } @@ -758,7 +772,7 @@ impl<'de> serde::Deserialize<'de> for AvroFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -829,7 +843,7 @@ impl<'de> serde::Deserialize<'de> for AvroOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -916,7 +930,7 @@ impl<'de> serde::Deserialize<'de> for Column { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1016,7 +1030,7 @@ impl<'de> serde::Deserialize<'de> for ColumnRelation { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1153,7 +1167,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1282,7 +1296,7 @@ impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = CompressionTypeVariant; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1379,7 +1393,7 @@ impl<'de> serde::Deserialize<'de> for Constraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1479,7 +1493,7 @@ impl<'de> serde::Deserialize<'de> for Constraints { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1570,7 +1584,7 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1840,7 +1854,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2204,7 +2218,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2407,7 +2421,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2530,7 +2544,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2656,7 +2670,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2779,7 +2793,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2905,7 +2919,7 @@ impl<'de> serde::Deserialize<'de> for Decimal32 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3028,7 +3042,7 @@ impl<'de> serde::Deserialize<'de> for Decimal32Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3154,7 +3168,7 @@ impl<'de> serde::Deserialize<'de> for Decimal64 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3277,7 +3291,7 @@ impl<'de> serde::Deserialize<'de> for Decimal64Type { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3389,7 +3403,7 @@ impl<'de> serde::Deserialize<'de> for DfField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3497,7 +3511,7 @@ impl<'de> serde::Deserialize<'de> for DfSchema { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3607,7 +3621,7 @@ impl<'de> serde::Deserialize<'de> for Dictionary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn 
expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3699,7 +3713,7 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3811,7 +3825,7 @@ impl<'de> serde::Deserialize<'de> for Field { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3950,7 +3964,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4060,7 +4074,7 @@ impl<'de> serde::Deserialize<'de> for IntervalDayTimeValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4182,7 +4196,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4286,7 +4300,7 @@ impl<'de> serde::Deserialize<'de> for IntervalUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = IntervalUnit; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4358,7 +4372,7 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinConstraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4431,7 +4445,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinSide; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4519,7 +4533,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = JoinType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4589,6 +4603,9 @@ impl serde::Serialize for JsonOptions { if self.compression_level.is_some() { len += 1; } + if self.newline_delimited.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.JsonOptions", len)?; if self.compression != 0 { let v = CompressionTypeVariant::try_from(self.compression) @@ -4603,6 +4620,9 @@ impl serde::Serialize for JsonOptions { if let Some(v) = self.compression_level.as_ref() { struct_ser.serialize_field("compressionLevel", v)?; } + if let Some(v) = self.newline_delimited.as_ref() { + struct_ser.serialize_field("newlineDelimited", v)?; + } struct_ser.end() } } @@ -4618,6 +4638,8 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { "schemaInferMaxRec", "compression_level", "compressionLevel", + "newline_delimited", + "newlineDelimited", ]; #[allow(clippy::enum_variant_names)] @@ -4625,6 +4647,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { Compression, SchemaInferMaxRec, CompressionLevel, + 
NewlineDelimited, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4633,7 +4656,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4649,6 +4672,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { "compression" => Ok(GeneratedField::Compression), "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), "compressionLevel" | "compression_level" => Ok(GeneratedField::CompressionLevel), + "newlineDelimited" | "newline_delimited" => Ok(GeneratedField::NewlineDelimited), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4671,6 +4695,7 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { let mut compression__ = None; let mut schema_infer_max_rec__ = None; let mut compression_level__ = None; + let mut newline_delimited__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Compression => { @@ -4695,12 +4720,19 @@ impl<'de> serde::Deserialize<'de> for JsonOptions { map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::NewlineDelimited => { + if newline_delimited__.is_some() { + return Err(serde::de::Error::duplicate_field("newlineDelimited")); + } + newline_delimited__ = map_.next_value()?; + } } } Ok(JsonOptions { compression: compression__.unwrap_or_default(), schema_infer_max_rec: schema_infer_max_rec__, compression_level: compression_level__, + newline_delimited: newline_delimited__, }) } } @@ -4748,7 +4780,7 @@ impl<'de> serde::Deserialize<'de> for JsonWriterOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4840,7 +4872,7 @@ impl<'de> serde::Deserialize<'de> for List { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4941,7 +4973,7 @@ impl<'de> serde::Deserialize<'de> for Map { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5041,7 +5073,7 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5119,7 +5151,7 @@ impl<'de> serde::Deserialize<'de> for NullEquality { struct 
GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = NullEquality; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5286,7 +5318,7 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5440,7 +5472,7 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnSpecificOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5540,7 +5572,7 @@ impl<'de> serde::Deserialize<'de> for ParquetFormat { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5976,7 +6008,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6371,7 +6403,7 @@ impl<'de> serde::Deserialize<'de> for Precision { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6460,7 +6492,7 @@ impl<'de> serde::Deserialize<'de> for PrecisionInfo { struct GeneratedVisitor; - 
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = PrecisionInfo; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6545,7 +6577,7 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6600,6 +6632,116 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { deserializer.deserialize_struct("datafusion_common.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RunEndEncoded { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.RunEndEncoded", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RunEndEncoded { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for 
GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RunEndEncoded; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.RunEndEncoded") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + } + } + Ok(RunEndEncoded { + run_ends_field: run_ends_field__, + values_field: values_field__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.RunEndEncoded", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarDictionaryValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -6648,7 +6790,7 @@ impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6758,7 +6900,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6892,7 +7034,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7028,7 +7170,7 @@ impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, 
formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7093,6 +7235,133 @@ impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { deserializer.deserialize_struct("datafusion_common.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarRunEndEncodedValue", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + 
E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarRunEndEncodedValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarRunEndEncodedValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + } + } + Ok(ScalarRunEndEncodedValue { + run_ends_field: run_ends_field__, + values_field: values_field__, + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarRunEndEncodedValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarTime32Value { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7143,7 +7412,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTime32Value { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7256,7 +7525,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTime64Value { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7393,7 +7662,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7635,6 +7904,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::UnionValue(v) => { struct_ser.serialize_field("unionValue", v)?; } + scalar_value::Value::RunEndEncodedValue(v) => { + struct_ser.serialize_field("runEndEncodedValue", v)?; + } } } struct_ser.end() @@ -7731,6 +8003,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeBinaryValue", "union_value", "unionValue", + "run_end_encoded_value", + "runEndEncodedValue", ]; #[allow(clippy::enum_variant_names)] @@ -7777,6 +8051,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { IntervalMonthDayNano, FixedSizeBinaryValue, UnionValue, + RunEndEncodedValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7785,7 +8060,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7840,6 +8115,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano" | 
"interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), + "runEndEncodedValue" | "run_end_encoded_value" => Ok(GeneratedField::RunEndEncodedValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8130,6 +8406,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("unionValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) +; + } + GeneratedField::RunEndEncodedValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndEncodedValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::RunEndEncodedValue) ; } } @@ -8189,7 +8472,7 @@ impl<'de> serde::Deserialize<'de> for Schema { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8310,7 +8593,7 @@ impl<'de> serde::Deserialize<'de> for Statistics { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8420,7 +8703,7 @@ impl<'de> serde::Deserialize<'de> for Struct { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8529,7 +8812,7 @@ impl<'de> serde::Deserialize<'de> for TableParquetOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8631,7 +8914,7 @@ impl<'de> serde::Deserialize<'de> for TimeUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = TimeUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8728,7 +9011,7 @@ impl<'de> serde::Deserialize<'de> for Timestamp { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8849,7 +9132,7 @@ impl<'de> serde::Deserialize<'de> for Union { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8970,7 +9253,7 @@ impl<'de> serde::Deserialize<'de> for UnionField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9059,7 +9342,7 @@ impl<'de> serde::Deserialize<'de> for UnionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = UnionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9170,7 +9453,7 @@ impl<'de> serde::Deserialize<'de> for UnionValue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type 
Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9290,7 +9573,7 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 16601dcf46977..a09826a29be52 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option, + #[prost(message, optional, boxed, tag = "3")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = 
"scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box), } } /// Useful for representing an empty enum variant in rust @@ -665,6 +685,9 @@ pub struct JsonOptions { /// Optional compression level #[prost(uint32, optional, tag = "3")] pub compression_level: ::core::option::Option, + /// Whether to read as newline-delimited JSON (default true). 
When false, expects JSON array format \[{},...\] + #[prost(bool, optional, tag = "4")] + pub newline_delimited: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs index b7e1c906d90f5..6f7fb7b89c0c4 100644 --- a/datafusion/proto-common/src/lib.rs +++ b/datafusion/proto-common/src/lib.rs @@ -24,7 +24,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] //! Serialize / Deserialize DataFusion Primitive Types to bytes //! diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index fee3656482005..79e3306a4df1b 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -180,7 +180,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, + union_types: convert_arc_fields_to_proto_fields( + fields.iter().map(|(_, item)| item), + )?, union_mode: union_mode.into(), type_ids: fields.iter().map(|(x, _)| x as i32).collect(), }) @@ -191,37 +193,44 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { value: Some(Box::new(value_type.as_ref().try_into()?)), })) } - DataType::Decimal32(precision, scale) => Self::Decimal32(protobuf::Decimal32Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal64(precision, scale) => Self::Decimal64(protobuf::Decimal64Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal128(precision, scale) => Self::Decimal128(protobuf::Decimal128Type { - precision: *precision as u32, - scale: *scale as i32, - }), - 
DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Map(field, sorted) => { - Self::Map(Box::new( - protobuf::Map { - field_type: Some(Box::new(field.as_ref().try_into()?)), - keys_sorted: *sorted, - } - )) - } - DataType::RunEndEncoded(_, _) => { - return Err(Error::General( - "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() - )) + DataType::Decimal32(precision, scale) => { + Self::Decimal32(protobuf::Decimal32Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal64(precision, scale) => { + Self::Decimal64(protobuf::Decimal64Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal128(precision, scale) => { + Self::Decimal128(protobuf::Decimal128Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal256(precision, scale) => { + Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Map(field, sorted) => Self::Map(Box::new(protobuf::Map { + field_type: Some(Box::new(field.as_ref().try_into()?)), + keys_sorted: *sorted, + })), + DataType::RunEndEncoded(run_ends_field, values_field) => { + Self::RunEndEncoded(Box::new(protobuf::RunEndEncoded { + run_ends_field: Some(Box::new(run_ends_field.as_ref().try_into()?)), + values_field: Some(Box::new(values_field.as_ref().try_into()?)), + })) } DataType::ListView(_) | DataType::LargeListView(_) => { - return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + return Err(Error::General(format!( + "Proto serialization error: {val} not yet supported" + ))); } }; @@ -680,6 +689,18 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ))), }) } + + ScalarValue::RunEndEncoded(run_ends_field, values_field, val) => { + Ok(protobuf::ScalarValue { + value: 
Some(Value::RunEndEncodedValue(Box::new( + protobuf::ScalarRunEndEncodedValue { + run_ends_field: Some(run_ends_field.as_ref().try_into()?), + values_field: Some(values_field.as_ref().try_into()?), + value: Some(Box::new(val.as_ref().try_into()?)), + }, + ))), + }) + } } } } @@ -990,6 +1011,7 @@ impl TryFrom<&JsonOptions> for protobuf::JsonOptions { compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64), compression_level: opts.compression_level, + newline_delimited: Some(opts.newline_delimited), }) } } @@ -1010,7 +1032,7 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using +// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using // Arrow IPC messages as a single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -1018,13 +1040,20 @@ fn encode_scalar_nested_value( ) -> Result { let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { Error::General(format!( - "Error creating temporary batch while encoding ScalarValue::List: {e}" + "Error creating temporary batch while encoding nested ScalarValue: {e}" )) })?; let ipc_gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let write_options = IpcWriteOptions::default(); + // The IPC writer requires pre-allocated dictionary IDs (normally assigned when + // serializing the schema). Populate `dict_tracker` by encoding the schema first. 
+ ipc_gen.schema_to_bytes_with_dictionary_tracker( + batch.schema().as_ref(), + &mut dict_tracker, + &write_options, + ); let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_message) = ipc_gen .encode( @@ -1034,7 +1063,7 @@ fn encode_scalar_nested_value( &mut compression_context, ) .map_err(|e| { - Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + Error::General(format!("Error encoding nested ScalarValue as IPC: {e}")) })?; let schema: protobuf::Schema = batch.schema().try_into()?; diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b00bd0dcc6bfd..3d17ed30d5726 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -28,9 +28,6 @@ license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } -# Exclude proto files so crates.io consumers don't need protoc -exclude = ["*.proto"] - [package.metadata.docs.rs] all-features = true @@ -69,6 +66,7 @@ datafusion-proto-common = { workspace = true } object_store = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } +rand = { workspace = true } serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index d446ab0d89741..8b48dfe70e6c7 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -37,5 +37,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.8.0" -prost-build = "=0.14.1" +pbjson-build = "=0.9.0" +prost-build = "=0.14.3" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd7dd3a6aff3c..7c0268867691e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -278,6 +278,7 @@ message DmlNode{ INSERT_APPEND = 3; 
INSERT_OVERWRITE = 4; INSERT_REPLACE = 5; + TRUNCATE = 6; } Type dml_type = 1; LogicalPlanNode input = 2; @@ -749,6 +750,8 @@ message PhysicalPlanNode { SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; AsyncFuncExecNode async_func = 36; + BufferExecNode buffer = 37; + ArrowScanExecNode arrow_scan = 38; } } @@ -758,6 +761,16 @@ message PartitionColumn { } +// Determines how file sink output paths are interpreted. +enum FileOutputMode { + // Infer output mode from the URL (extension/trailing `/` heuristic). + FILE_OUTPUT_MODE_AUTOMATIC = 0; + // Write to a single file at the exact output path. + FILE_OUTPUT_MODE_SINGLE_FILE = 1; + // Write to a directory with generated filenames. + FILE_OUTPUT_MODE_DIRECTORY = 2; +} + message FileSinkConfig { reserved 6; // writer_mode reserved 8; // was `overwrite` which has been superseded by `insert_op` @@ -770,6 +783,8 @@ message FileSinkConfig { bool keep_partition_by_columns = 9; InsertOp insert_op = 10; string file_extension = 11; + // Determines how the output path is interpreted. + FileOutputMode file_output_mode = 12; } enum InsertOp { @@ -837,6 +852,14 @@ message PhysicalExprNode { // Was date_time_interval_expr reserved 17; + // Unique identifier for this expression to do deduplication during deserialization. + // When serializing, this is set to a unique identifier for each combination of + // expression, process and serialization run. + // When deserializing, if this ID has been seen before, the cached Arc is returned + // instead of creating a new one, enabling reconstruction of referential integrity + // across serde roundtrips. 
+ optional uint64 expr_id = 30; + oneof ExprType { // column references PhysicalColumn column = 1; @@ -1006,6 +1029,7 @@ message FilterExecNode { PhysicalExprNode expr = 2; uint32 default_filter_selectivity = 3; repeated uint32 projection = 9; + uint32 batch_size = 10; } message FileGroup { @@ -1083,6 +1107,10 @@ message AvroScanExecNode { FileScanExecConf base_conf = 1; } +message ArrowScanExecNode { + FileScanExecConf base_conf = 1; +} + message MemoryScanExecNode { repeated bytes partitions = 1; datafusion_common.Schema schema = 2; @@ -1111,6 +1139,7 @@ message HashJoinExecNode { datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; + bool null_aware = 10; } enum StreamPartitionMode { @@ -1190,6 +1219,7 @@ enum AggregateMode { FINAL_PARTITIONED = 2; SINGLE = 3; SINGLE_PARTITIONED = 4; + PARTIAL_REDUCE = 5; } message PartiallySortedInputOrderMode { @@ -1219,6 +1249,8 @@ message MaybePhysicalSortExprs { message AggLimit { // wrap into a message to make it optional uint64 limit = 1; + // Optional ordering direction for TopK aggregation (true = descending, false = ascending) + optional bool descending = 2; } message AggregateExecNode { @@ -1412,3 +1444,8 @@ message AsyncFuncExecNode { repeated PhysicalExprNode async_exprs = 2; repeated string async_expr_names = 3; } + +message BufferExecNode { + PhysicalPlanNode input = 1; + uint64 capacity = 2; +} \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d95bdd388699e..84b15ea9a8920 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -21,7 +21,8 @@ use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use crate::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + DefaultPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + 
PhysicalProtoConverterExtension, }; use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; @@ -276,16 +277,18 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_to_bytes_with_extension_codec(plan, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } /// Serialize a PhysicalPlan as JSON #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; + let proto_converter = DefaultPhysicalProtoConverter {}; + let protobuf = proto_converter + .execution_plan_to_proto(&plan, &extension_codec) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } @@ -295,8 +298,18 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) +} + +/// Serialize a PhysicalPlan as bytes, using the provided extension codec +/// and protobuf converter. 
+pub fn physical_plan_to_bytes_with_proto_converter( + plan: Arc, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let protobuf = proto_converter.execution_plan_to_proto(&plan, extension_codec)?; let mut buffer = BytesMut::new(); protobuf .encode(&mut buffer) @@ -313,7 +326,8 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) } /// Deserialize a PhysicalPlan from bytes @@ -322,7 +336,13 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + &extension_codec, + &proto_converter, + ) } /// Deserialize a PhysicalPlan from bytes @@ -330,8 +350,24 @@ pub fn physical_plan_from_bytes_with_extension_codec( bytes: &[u8], ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, +) -> Result> { + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + extension_codec, + &proto_converter, + ) +} + +/// Deserialize a PhysicalPlan from bytes +pub fn physical_plan_from_bytes_with_proto_converter( + bytes: &[u8], + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - 
protobuf.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, extension_codec, &protobuf) } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 16601dcf46977..a09826a29be52 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option, + #[prost(message, optional, boxed, tag = "3")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 
41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box), } } /// Useful for representing an empty enum variant in rust @@ -665,6 +685,9 @@ pub struct JsonOptions { /// Optional compression level #[prost(uint32, optional, tag = "3")] pub compression_level: ::core::option::Option, + /// Whether to read as newline-delimited JSON (default true). When false, expects JSON array format \[{},...\] + #[prost(bool, optional, tag = "4")] + pub newline_delimited: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { diff --git a/datafusion/proto/src/generated/mod.rs b/datafusion/proto/src/generated/mod.rs index adf5125457c14..ca32b1500d57b 100644 --- a/datafusion/proto/src/generated/mod.rs +++ b/datafusion/proto/src/generated/mod.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-#![allow(clippy::allow_attributes)] - +// This code is generated so we don't want to fix any lint violations manually +#[allow(clippy::allow_attributes)] #[allow(clippy::all)] #[rustfmt::skip] pub mod datafusion { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e269606d163a3..5b2b9133ce13a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9,12 +9,18 @@ impl serde::Serialize for AggLimit { if self.limit != 0 { len += 1; } + if self.descending.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?; if self.limit != 0 { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?; } + if let Some(v) = self.descending.as_ref() { + struct_ser.serialize_field("descending", v)?; + } struct_ser.end() } } @@ -26,11 +32,13 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { const FIELDS: &[&str] = &[ "limit", + "descending", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Limit, + Descending, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -39,7 +47,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -53,6 +61,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { match value { "limit" => Ok(GeneratedField::Limit), + "descending" => Ok(GeneratedField::Descending), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -73,6 +82,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { V: serde::de::MapAccess<'de>, { let mut limit__ = None; + let mut descending__ = None; 
while let Some(k) = map_.next_key()? { match k { GeneratedField::Limit => { @@ -83,10 +93,17 @@ impl<'de> serde::Deserialize<'de> for AggLimit { Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::Descending => { + if descending__.is_some() { + return Err(serde::de::Error::duplicate_field("descending")); + } + descending__ = map_.next_value()?; + } } } Ok(AggLimit { limit: limit__.unwrap_or_default(), + descending: descending__, }) } } @@ -230,7 +247,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -393,6 +410,7 @@ impl serde::Serialize for AggregateMode { Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", Self::SinglePartitioned => "SINGLE_PARTITIONED", + Self::PartialReduce => "PARTIAL_REDUCE", }; serializer.serialize_str(variant) } @@ -409,11 +427,12 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED", "SINGLE", "SINGLE_PARTITIONED", + "PARTIAL_REDUCE", ]; struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = AggregateMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -454,6 +473,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), "SINGLE" => Ok(AggregateMode::Single), "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), + "PARTIAL_REDUCE" => Ok(AggregateMode::PartialReduce), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -518,7 +538,7 @@ impl<'de> serde::Deserialize<'de> for AggregateNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -683,7 +703,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -854,7 +874,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -999,7 +1019,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1125,7 +1145,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1226,7 +1246,7 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1278,6 +1298,98 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { deserializer.deserialize_struct("datafusion.AnalyzedLogicalPlanType", FIELDS, GeneratedVisitor) } } +impl 
serde::Serialize for ArrowScanExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.base_conf.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ArrowScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowScanExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "base_conf", + "baseConf", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + BaseConf, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowScanExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ArrowScanExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut base_conf__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); + } + base_conf__ = map_.next_value()?; + } + } + } + Ok(ArrowScanExecNode { + base_conf: base_conf__, + }) + } + } + deserializer.deserialize_struct("datafusion.ArrowScanExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AsyncFuncExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -1335,7 +1447,7 @@ impl<'de> serde::Deserialize<'de> for AsyncFuncExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1445,7 +1557,7 @@ impl<'de> serde::Deserialize<'de> for AvroScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1536,7 +1648,7 @@ impl<'de> serde::Deserialize<'de> for BareTableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1651,7 +1763,7 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1777,7 +1889,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1838,6 +1950,118 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { deserializer.deserialize_struct("datafusion.BinaryExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for BufferExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.capacity != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BufferExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.capacity != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("capacity", ToString::to_string(&self.capacity).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for BufferExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "capacity", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Capacity, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "capacity" => Ok(GeneratedField::Capacity), + _ => 
Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = BufferExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.BufferExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut capacity__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Capacity => { + if capacity__.is_some() { + return Err(serde::de::Error::duplicate_field("capacity")); + } + capacity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(BufferExecNode { + input: input__, + capacity: capacity__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.BufferExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -1895,7 +2119,7 @@ impl<'de> serde::Deserialize<'de> for CaseNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2013,7 +2237,7 @@ impl<'de> serde::Deserialize<'de> for CastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2130,7 +2354,7 @@ impl<'de> serde::Deserialize<'de> for 
CoalesceBatchesExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2251,7 +2475,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2363,7 +2587,7 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2474,7 +2698,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListItem { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2585,7 +2809,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursion { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2687,7 +2911,7 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2778,7 +3002,7 @@ impl<'de> 
serde::Deserialize<'de> for CooperativeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -2898,7 +3122,7 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3036,7 +3260,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3163,7 +3387,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3382,7 +3606,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3627,7 +3851,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3762,7 +3986,7 @@ impl<'de> 
serde::Deserialize<'de> for CrossJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -3870,7 +4094,7 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4038,7 +4262,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4201,7 +4425,7 @@ impl<'de> serde::Deserialize<'de> for CsvSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4327,7 +4551,7 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4453,7 +4677,7 @@ impl<'de> serde::Deserialize<'de> for CteWorkTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4553,7 +4777,7 @@ impl<'de> serde::Deserialize<'de> for 
CubeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4680,7 +4904,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4796,7 +5020,7 @@ impl<'de> serde::Deserialize<'de> for DateUnit { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = DateUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4880,7 +5104,7 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -4998,7 +5222,7 @@ impl<'de> serde::Deserialize<'de> for DistinctOnNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5144,7 +5368,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5236,6 +5460,7 @@ impl serde::Serialize for dml_node::Type { Self::InsertAppend => "INSERT_APPEND", 
Self::InsertOverwrite => "INSERT_OVERWRITE", Self::InsertReplace => "INSERT_REPLACE", + Self::Truncate => "TRUNCATE", }; serializer.serialize_str(variant) } @@ -5253,11 +5478,12 @@ impl<'de> serde::Deserialize<'de> for dml_node::Type { "INSERT_APPEND", "INSERT_OVERWRITE", "INSERT_REPLACE", + "TRUNCATE", ]; struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = dml_node::Type; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5299,6 +5525,7 @@ impl<'de> serde::Deserialize<'de> for dml_node::Type { "INSERT_APPEND" => Ok(dml_node::Type::InsertAppend), "INSERT_OVERWRITE" => Ok(dml_node::Type::InsertOverwrite), "INSERT_REPLACE" => Ok(dml_node::Type::InsertReplace), + "TRUNCATE" => Ok(dml_node::Type::Truncate), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -5362,7 +5589,7 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5471,7 +5698,7 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5563,7 +5790,7 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5671,7 +5898,7 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { { struct GeneratedVisitor; 
- impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5788,7 +6015,7 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5888,7 +6115,7 @@ impl<'de> serde::Deserialize<'de> for FileGroup { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -5940,6 +6167,80 @@ impl<'de> serde::Deserialize<'de> for FileGroup { deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FileOutputMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Automatic => "FILE_OUTPUT_MODE_AUTOMATIC", + Self::SingleFile => "FILE_OUTPUT_MODE_SINGLE_FILE", + Self::Directory => "FILE_OUTPUT_MODE_DIRECTORY", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for FileOutputMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "FILE_OUTPUT_MODE_AUTOMATIC", + "FILE_OUTPUT_MODE_SINGLE_FILE", + "FILE_OUTPUT_MODE_DIRECTORY", + ]; + + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = FileOutputMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + 
+ fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "FILE_OUTPUT_MODE_AUTOMATIC" => Ok(FileOutputMode::Automatic), + "FILE_OUTPUT_MODE_SINGLE_FILE" => Ok(FileOutputMode::SingleFile), + "FILE_OUTPUT_MODE_DIRECTORY" => Ok(FileOutputMode::Directory), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for FileRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5991,7 +6292,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6183,7 +6484,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6362,6 +6663,9 @@ impl serde::Serialize for FileSinkConfig { if !self.file_extension.is_empty() { len += 1; } + if self.file_output_mode != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; if 
!self.object_store_url.is_empty() { struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; @@ -6389,6 +6693,11 @@ impl serde::Serialize for FileSinkConfig { if !self.file_extension.is_empty() { struct_ser.serialize_field("fileExtension", &self.file_extension)?; } + if self.file_output_mode != 0 { + let v = FileOutputMode::try_from(self.file_output_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.file_output_mode)))?; + struct_ser.serialize_field("fileOutputMode", &v)?; + } struct_ser.end() } } @@ -6415,6 +6724,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "insertOp", "file_extension", "fileExtension", + "file_output_mode", + "fileOutputMode", ]; #[allow(clippy::enum_variant_names)] @@ -6427,6 +6738,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { KeepPartitionByColumns, InsertOp, FileExtension, + FileOutputMode, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6435,7 +6747,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6456,6 +6768,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), + "fileOutputMode" | "file_output_mode" => Ok(GeneratedField::FileOutputMode), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6483,6 +6796,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut keep_partition_by_columns__ = None; let mut insert_op__ = None; let mut file_extension__ = None; + let mut file_output_mode__ 
= None; while let Some(k) = map_.next_key()? { match k { GeneratedField::ObjectStoreUrl => { @@ -6533,6 +6847,12 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } file_extension__ = Some(map_.next_value()?); } + GeneratedField::FileOutputMode => { + if file_output_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("fileOutputMode")); + } + file_output_mode__ = Some(map_.next_value::()? as i32); + } } } Ok(FileSinkConfig { @@ -6544,6 +6864,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), insert_op: insert_op__.unwrap_or_default(), file_extension: file_extension__.unwrap_or_default(), + file_output_mode: file_output_mode__.unwrap_or_default(), }) } } @@ -6570,6 +6891,9 @@ impl serde::Serialize for FilterExecNode { if !self.projection.is_empty() { len += 1; } + if self.batch_size != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -6583,6 +6907,9 @@ impl serde::Serialize for FilterExecNode { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if self.batch_size != 0 { + struct_ser.serialize_field("batchSize", &self.batch_size)?; + } struct_ser.end() } } @@ -6598,6 +6925,8 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { "default_filter_selectivity", "defaultFilterSelectivity", "projection", + "batch_size", + "batchSize", ]; #[allow(clippy::enum_variant_names)] @@ -6606,6 +6935,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { Expr, DefaultFilterSelectivity, Projection, + BatchSize, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6614,7 +6944,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6631,6 +6961,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { "expr" => Ok(GeneratedField::Expr), "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), "projection" => Ok(GeneratedField::Projection), + "batchSize" | "batch_size" => Ok(GeneratedField::BatchSize), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6654,6 +6985,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { let mut expr__ = None; let mut default_filter_selectivity__ = None; let mut projection__ = None; + let mut batch_size__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -6685,6 +7017,14 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::BatchSize => { + if batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("batchSize")); + } + batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } } } Ok(FilterExecNode { @@ -6692,6 +7032,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { expr: expr__, default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), projection: projection__.unwrap_or_default(), + batch_size: batch_size__.unwrap_or_default(), }) } } @@ -6737,7 +7078,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6846,7 +7187,7 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> 
for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -6957,7 +7298,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsContainsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7087,7 +7428,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsDate { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7259,7 +7600,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsInt64 { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7439,7 +7780,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsTimestamp { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7566,7 +7907,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesName { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GenerateSeriesName; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7690,7 +8031,7 @@ impl<'de> serde::Deserialize<'de> for GenerateSeriesNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + 
impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7844,7 +8185,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7957,7 +8298,7 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8041,6 +8382,9 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { len += 1; } + if self.null_aware { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; if let Some(v) = self.left.as_ref() { struct_ser.serialize_field("left", v)?; @@ -8072,6 +8416,9 @@ impl serde::Serialize for HashJoinExecNode { if !self.projection.is_empty() { struct_ser.serialize_field("projection", &self.projection)?; } + if self.null_aware { + struct_ser.serialize_field("nullAware", &self.null_aware)?; + } struct_ser.end() } } @@ -8093,6 +8440,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality", "filter", "projection", + "null_aware", + "nullAware", ]; #[allow(clippy::enum_variant_names)] @@ -8105,6 +8454,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { NullEquality, Filter, Projection, + NullAware, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8113,7 +8463,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8134,6 +8484,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), + "nullAware" | "null_aware" => Ok(GeneratedField::NullAware), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8161,6 +8512,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; + let mut null_aware__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { @@ -8214,6 +8566,12 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::NullAware => { + if null_aware__.is_some() { + return Err(serde::de::Error::duplicate_field("nullAware")); + } + null_aware__ = Some(map_.next_value()?); + } } } Ok(HashJoinExecNode { @@ -8225,6 +8583,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), + null_aware: null_aware__.unwrap_or_default(), }) } } @@ -8282,7 +8641,7 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8409,7 +8768,7 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8543,7 +8902,7 @@ impl<'de> serde::Deserialize<'de> for InListNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8641,7 +9000,7 @@ impl<'de> serde::Deserialize<'de> for InsertOp { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = InsertOp; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8726,7 +9085,7 @@ impl<'de> serde::Deserialize<'de> for InterleaveExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8817,7 +9176,7 @@ impl<'de> serde::Deserialize<'de> for IsFalse { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8908,7 +9267,7 @@ impl<'de> serde::Deserialize<'de> for IsNotFalse { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -8999,7 +9358,7 @@ impl<'de> serde::Deserialize<'de> for IsNotNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9090,7 
+9449,7 @@ impl<'de> serde::Deserialize<'de> for IsNotTrue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9181,7 +9540,7 @@ impl<'de> serde::Deserialize<'de> for IsNotUnknown { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9272,7 +9631,7 @@ impl<'de> serde::Deserialize<'de> for IsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9363,7 +9722,7 @@ impl<'de> serde::Deserialize<'de> for IsTrue { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9454,7 +9813,7 @@ impl<'de> serde::Deserialize<'de> for IsUnknown { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9562,7 +9921,7 @@ impl<'de> serde::Deserialize<'de> for JoinFilter { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9738,7 +10097,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { { 
struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -9900,7 +10259,7 @@ impl<'de> serde::Deserialize<'de> for JoinOn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10001,7 +10360,7 @@ impl<'de> serde::Deserialize<'de> for JsonScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10101,7 +10460,7 @@ impl<'de> serde::Deserialize<'de> for JsonSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10227,7 +10586,7 @@ impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10370,7 +10729,7 @@ impl<'de> serde::Deserialize<'de> for LikeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10508,7 +10867,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { { struct GeneratedVisitor; - impl<'de> 
serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10621,7 +10980,7 @@ impl<'de> serde::Deserialize<'de> for ListIndex { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10728,7 +11087,7 @@ impl<'de> serde::Deserialize<'de> for ListRange { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -10846,7 +11205,7 @@ impl<'de> serde::Deserialize<'de> for ListUnnest { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11060,7 +11419,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11284,7 +11643,7 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11386,7 +11745,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprList { { struct GeneratedVisitor; - impl<'de> 
serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11649,7 +12008,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -11982,7 +12341,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12083,7 +12442,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12359,7 +12718,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12699,7 +13058,7 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12791,7 +13150,7 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { { struct GeneratedVisitor; - 
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -12924,7 +13283,7 @@ impl<'de> serde::Deserialize<'de> for MemoryScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13068,7 +13427,7 @@ impl<'de> serde::Deserialize<'de> for NamedStructField { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13159,7 +13518,7 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13285,7 +13644,7 @@ impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13415,7 +13774,7 @@ impl<'de> serde::Deserialize<'de> for Not { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13493,7 +13852,7 @@ impl<'de> serde::Deserialize<'de> for NullTreatment { struct GeneratedVisitor; - impl<'de> 
serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = NullTreatment; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13578,7 +13937,7 @@ impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13670,7 +14029,7 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13779,7 +14138,7 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -13897,7 +14256,7 @@ impl<'de> serde::Deserialize<'de> for ParquetSink { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14023,7 +14382,7 @@ impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14149,7 +14508,7 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { { struct 
GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14249,7 +14608,7 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14352,7 +14711,7 @@ impl<'de> serde::Deserialize<'de> for PartitionColumn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14441,7 +14800,7 @@ impl<'de> serde::Deserialize<'de> for PartitionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = PartitionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14560,7 +14919,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14730,7 +15089,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -14889,7 +15248,7 @@ impl<'de> serde::Deserialize<'de> for Partitioning { { struct 
GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15054,7 +15413,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15209,7 +15568,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15325,7 +15684,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15452,7 +15811,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15570,7 +15929,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15678,7 +16037,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { { 
struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15796,7 +16155,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -15874,10 +16233,18 @@ impl serde::Serialize for PhysicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; + if self.expr_id.is_some() { + len += 1; + } if self.expr_type.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.expr_id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("exprId", ToString::to_string(&v).as_str())?; + } if let Some(v) = self.expr_type.as_ref() { match v { physical_expr_node::ExprType::Column(v) => { @@ -15949,6 +16316,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "expr_id", + "exprId", "column", "literal", "binary_expr", @@ -15985,6 +16354,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { + ExprId, Column, Literal, BinaryExpr, @@ -16012,7 +16382,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16025,6 +16395,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: 
serde::de::Error, { match value { + "exprId" | "expr_id" => Ok(GeneratedField::ExprId), "column" => Ok(GeneratedField::Column), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), @@ -16063,9 +16434,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { where V: serde::de::MapAccess<'de>, { + let mut expr_id__ = None; let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::ExprId => { + if expr_id__.is_some() { + return Err(serde::de::Error::duplicate_field("exprId")); + } + expr_id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); @@ -16202,6 +16582,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { } } Ok(PhysicalExprNode { + expr_id: expr_id__, expr_type: expr_type__, }) } @@ -16258,7 +16639,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16370,7 +16751,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16521,7 +16902,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16677,7 
+17058,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16795,7 +17176,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16904,7 +17285,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -16995,7 +17376,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17111,7 +17492,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17229,7 +17610,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ 
-17320,7 +17701,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17491,6 +17872,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::AsyncFunc(v) => { struct_ser.serialize_field("asyncFunc", v)?; } + physical_plan_node::PhysicalPlanType::Buffer(v) => { + struct_ser.serialize_field("buffer", v)?; + } + physical_plan_node::PhysicalPlanType::ArrowScan(v) => { + struct_ser.serialize_field("arrowScan", v)?; + } } } struct_ser.end() @@ -17558,6 +17945,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "memoryScan", "async_func", "asyncFunc", + "buffer", + "arrow_scan", + "arrowScan", ]; #[allow(clippy::enum_variant_names)] @@ -17597,6 +17987,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SortMergeJoin, MemoryScan, AsyncFunc, + Buffer, + ArrowScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17605,7 +17997,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -17653,6 +18045,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), + "buffer" => Ok(GeneratedField::Buffer), + "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17918,6 +18312,20 @@ 
impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("asyncFunc")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AsyncFunc) +; + } + GeneratedField::Buffer => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("buffer")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Buffer) +; + } + GeneratedField::ArrowScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ArrowScan) ; } } @@ -18014,7 +18422,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18169,7 +18577,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18279,7 +18687,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18379,7 +18787,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18489,7 +18897,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18671,7 +19079,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18868,7 +19276,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -18988,7 +19396,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19143,7 +19551,7 @@ impl<'de> serde::Deserialize<'de> for PlanType { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19356,7 +19764,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19474,7 +19882,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionColumns { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19582,7 +19990,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19699,7 +20107,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExpr { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19799,7 +20207,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionExprs { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -19910,7 +20318,7 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20037,7 +20445,7 @@ impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20175,7 +20583,7 @@ impl<'de> serde::Deserialize<'de> for RecursiveQueryNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20301,7 +20709,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20421,7 +20829,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20529,7 +20937,7 @@ impl<'de> serde::Deserialize<'de> for RollupNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20640,7 +21048,7 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20751,7 +21159,7 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20844,7 +21252,7 @@ impl<'de> serde::Deserialize<'de> for SelectionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -20943,7 +21351,7 @@ impl<'de> serde::Deserialize<'de> for SelectionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21068,7 +21476,7 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21213,7 +21621,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21350,7 +21758,7 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21460,7 +21868,7 @@ impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21606,7 +22014,7 @@ impl<'de> serde::Deserialize<'de> for SortMergeJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21769,7 +22177,7 @@ impl<'de> serde::Deserialize<'de> for SortNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21898,7 +22306,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -21996,7 +22404,7 @@ impl<'de> serde::Deserialize<'de> for StreamPartitionMode { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = StreamPartitionMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22089,7 +22497,7 @@ impl<'de> serde::Deserialize<'de> for StringifiedPlan { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22197,7 +22605,7 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22372,7 +22780,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22549,7 +22957,7 @@ impl<'de> serde::Deserialize<'de> for TableReference { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22666,7 +23074,7 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22766,7 +23174,7 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22857,7 +23265,7 @@ impl<'de> serde::Deserialize<'de> for UnionNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -22948,7 +23356,7 @@ impl<'de> serde::Deserialize<'de> for UnknownColumn { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl 
serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23039,7 +23447,7 @@ impl<'de> serde::Deserialize<'de> for Unnest { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23164,7 +23572,7 @@ impl<'de> serde::Deserialize<'de> for UnnestExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23346,7 +23754,7 @@ impl<'de> serde::Deserialize<'de> for UnnestNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23506,7 +23914,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23618,7 +24026,7 @@ impl<'de> serde::Deserialize<'de> for ValuesNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23753,7 +24161,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for 
GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23890,7 +24298,7 @@ impl<'de> serde::Deserialize<'de> for WhenThen { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -23990,7 +24398,7 @@ impl<'de> serde::Deserialize<'de> for Wildcard { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24122,7 +24530,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24339,7 +24747,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24535,7 +24943,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrame { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24657,7 +25065,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBound { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value 
= GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24746,7 +25154,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBoundType { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = WindowFrameBoundType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24820,7 +25228,7 @@ impl<'de> serde::Deserialize<'de> for WindowFrameUnits { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = WindowFrameUnits; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -24914,7 +25322,7 @@ impl<'de> serde::Deserialize<'de> for WindowNode { { struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + impl serde::de::Visitor<'_> for GeneratedVisitor { type Value = GeneratedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cf343e0258d0b..d9602665c284a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -444,6 +444,7 @@ pub mod dml_node { InsertAppend = 3, InsertOverwrite = 4, InsertReplace = 5, + Truncate = 6, } impl Type { /// String value of the enum field names used in the ProtoBuf definition. @@ -458,6 +459,7 @@ pub mod dml_node { Self::InsertAppend => "INSERT_APPEND", Self::InsertOverwrite => "INSERT_OVERWRITE", Self::InsertReplace => "INSERT_REPLACE", + Self::Truncate => "TRUNCATE", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -469,6 +471,7 @@ pub mod dml_node { "INSERT_APPEND" => Some(Self::InsertAppend), "INSERT_OVERWRITE" => Some(Self::InsertOverwrite), "INSERT_REPLACE" => Some(Self::InsertReplace), + "TRUNCATE" => Some(Self::Truncate), _ => None, } } @@ -1076,7 +1079,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38" )] pub physical_plan_type: ::core::option::Option, } @@ -1156,6 +1159,10 @@ pub mod physical_plan_node { MemoryScan(super::MemoryScanExecNode), #[prost(message, tag = "36")] AsyncFunc(::prost::alloc::boxed::Box), + #[prost(message, tag = "37")] + Buffer(::prost::alloc::boxed::Box), + #[prost(message, tag = "38")] + ArrowScan(super::ArrowScanExecNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1183,6 +1190,9 @@ pub struct FileSinkConfig { pub insert_op: i32, #[prost(string, tag = "11")] pub file_extension: ::prost::alloc::string::String, + /// Determines how the output path is interpreted. + #[prost(enumeration = "FileOutputMode", tag = "12")] + pub file_output_mode: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JsonSink { @@ -1274,6 +1284,14 @@ pub struct PhysicalExtensionNode { /// physical expressions #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExprNode { + /// Unique identifier for this expression to do deduplication during deserialization. + /// When serializing, this is set to a unique identifier for each combination of + /// expression, process and serialization run. 
+ /// When deserializing, if this ID has been seen before, the cached Arc is returned + /// instead of creating a new one, enabling reconstruction of referential integrity + /// across serde roundtrips. + #[prost(uint64, optional, tag = "30")] + pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" @@ -1543,6 +1561,8 @@ pub struct FilterExecNode { pub default_filter_selectivity: u32, #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, + #[prost(uint32, tag = "10")] + pub batch_size: u32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileGroup { @@ -1651,6 +1671,11 @@ pub struct AvroScanExecNode { pub base_conf: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowScanExecNode { + #[prost(message, optional, tag = "1")] + pub base_conf: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct MemoryScanExecNode { #[prost(bytes = "vec", repeated, tag = "1")] pub partitions: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, @@ -1688,6 +1713,8 @@ pub struct HashJoinExecNode { pub filter: ::core::option::Option, #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "10")] + pub null_aware: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { @@ -1830,6 +1857,9 @@ pub struct AggLimit { /// wrap into a message to make it optional #[prost(uint64, tag = "1")] pub limit: u64, + /// Optional ordering direction for TopK aggregation (true = descending, false = ascending) + #[prost(bool, optional, tag = "2")] + pub descending: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { @@ -2134,6 +2164,13 @@ pub struct AsyncFuncExecNode { #[prost(string, repeated, tag = "3")] pub async_expr_names: 
::prost::alloc::vec::Vec<::prost::alloc::string::String>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BufferExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint64, tag = "2")] + pub capacity: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { @@ -2244,6 +2281,39 @@ impl DateUnit { } } } +/// Determines how file sink output paths are interpreted. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum FileOutputMode { + /// Infer output mode from the URL (extension/trailing `/` heuristic). + Automatic = 0, + /// Write to a single file at the exact output path. + SingleFile = 1, + /// Write to a directory with generated filenames. + Directory = 2, +} +impl FileOutputMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Automatic => "FILE_OUTPUT_MODE_AUTOMATIC", + Self::SingleFile => "FILE_OUTPUT_MODE_SINGLE_FILE", + Self::Directory => "FILE_OUTPUT_MODE_DIRECTORY", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FILE_OUTPUT_MODE_AUTOMATIC" => Some(Self::Automatic), + "FILE_OUTPUT_MODE_SINGLE_FILE" => Some(Self::SingleFile), + "FILE_OUTPUT_MODE_DIRECTORY" => Some(Self::Directory), + _ => None, + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum InsertOp { @@ -2336,6 +2406,7 @@ pub enum AggregateMode { FinalPartitioned = 2, Single = 3, SinglePartitioned = 4, + PartialReduce = 5, } impl AggregateMode { /// String value of the enum field names used in the ProtoBuf definition. @@ -2349,6 +2420,7 @@ impl AggregateMode { Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", Self::SinglePartitioned => "SINGLE_PARTITIONED", + Self::PartialReduce => "PARTIAL_REDUCE", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2359,6 +2431,7 @@ impl AggregateMode { "FINAL_PARTITIONED" => Some(Self::FinalPartitioned), "SINGLE" => Some(Self::Single), "SINGLE_PARTITIONED" => Some(Self::SinglePartitioned), + "PARTIAL_REDUCE" => Some(Self::PartialReduce), _ => None, } } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index e30d2a22348cd..7ddc930fa257e 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! 
Serialize / Deserialize DataFusion Plans to bytes diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 436a06493766d..08f42b0af7290 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -241,6 +241,7 @@ impl JsonOptionsProto { compression: options.compression as i32, schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64), compression_level: options.compression_level, + newline_delimited: Some(options.newline_delimited), } } else { JsonOptionsProto::default() @@ -260,6 +261,7 @@ impl From<&JsonOptionsProto> for JsonOptions { }, schema_infer_max_rec: proto.schema_infer_max_rec.map(|v| v as usize), compression_level: proto.compression_level, + newline_delimited: proto.newline_delimited.unwrap_or(true), } } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 179fe8bb7d7fe..a653f517b7275 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -239,6 +239,7 @@ impl From for WriteOp { } protobuf::dml_node::Type::InsertReplace => WriteOp::Insert(InsertOp::Replace), protobuf::dml_node::Type::Ctas => WriteOp::Ctas, + protobuf::dml_node::Type::Truncate => WriteOp::Truncate, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6e4e5d0b6eea4..fe63fce6ee260 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -578,6 +578,7 @@ pub fn serialize_expr( Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } + | Expr::SetComparison(_) | Expr::OuterReferenceColumn { .. 
} => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 @@ -728,6 +729,7 @@ impl From<&WriteOp> for protobuf::dml_node::Type { WriteOp::Delete => protobuf::dml_node::Type::Delete, WriteOp::Update => protobuf::dml_node::Type::Update, WriteOp::Ctas => protobuf::dml_node::Type::Ctas, + WriteOp::Truncate => protobuf::dml_node::Type::Truncate, } } } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 3cfc796700dae..e424be162648b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,14 +21,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::{Field, Schema}; use arrow::ipc::reader::StreamReader; use chrono::{TimeZone, Utc}; -use datafusion_expr::dml::InsertOp; -use object_store::ObjectMeta; -use object_store::path::Path; - -use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, Result, internal_datafusion_err, not_impl_err}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; @@ -42,6 +37,7 @@ use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ @@ -52,13 +48,16 @@ use datafusion_physical_plan::joins::{HashExpr, SeededRandomState}; use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, 
WindowExpr}; use datafusion_proto_common::common::proto_error; +use object_store::ObjectMeta; +use object_store::path::Path; -use crate::convert_required; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; use crate::logical_plan::{self}; -use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; - -use super::PhysicalExtensionCodec; +use crate::{convert_required, protobuf}; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -80,9 +79,15 @@ pub fn parse_physical_sort_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; + let expr = proto_converter.proto_to_physical_expr( + expr.as_ref(), + ctx, + input_schema, + codec, + )?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -107,10 +112,13 @@ pub fn parse_physical_sort_exprs( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { proto .iter() - .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .map(|sort_expr| { + parse_physical_sort_expr(sort_expr, ctx, input_schema, codec, proto_converter) + }) .collect() } @@ -129,12 +137,25 @@ pub fn parse_physical_window_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; - let partition_by = - parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; - - let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; + let window_node_expr = + 
parse_physical_exprs(&proto.args, ctx, input_schema, codec, proto_converter)?; + let partition_by = parse_physical_exprs( + &proto.partition_by, + ctx, + input_schema, + codec, + proto_converter, + )?; + + let order_by = parse_physical_sort_exprs( + &proto.order_by, + ctx, + input_schema, + codec, + proto_converter, + )?; let window_frame = proto .window_frame @@ -188,13 +209,14 @@ pub fn parse_physical_exprs<'a, I>( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>> where I: IntoIterator, { protos .into_iter() - .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) + .map(|p| proto_converter.proto_to_physical_expr(p, ctx, input_schema, codec)) .collect::>>() } @@ -212,6 +234,32 @@ pub fn parse_physical_expr( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, +) -> Result> { + parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Parses a physical expression from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical expression node +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. 
+/// * `proto_converter` - Conversion functions for physical plans and expressions +pub fn parse_physical_expr_with_converter( + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let expr_type = proto .expr_type @@ -232,6 +280,7 @@ pub fn parse_physical_expr( "left", input_schema, codec, + proto_converter, )?, logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( @@ -240,6 +289,7 @@ pub fn parse_physical_expr( "right", input_schema, codec, + proto_converter, )?, )), ExprType::AggregateExpr(_) => { @@ -262,6 +312,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::IsNotNullExpr(e) => { @@ -271,6 +322,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( @@ -279,6 +331,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)), ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( @@ -287,6 +340,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?)) } ExprType::InList(e) => in_list( @@ -296,15 +350,23 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec)?, + parse_physical_exprs(&e.list, ctx, input_schema, codec, proto_converter)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, e.when_then_expr .iter() @@ -316,6 +378,7 @@ pub fn parse_physical_expr( "when_expr", input_schema, codec, + proto_converter, )?, 
parse_required_physical_expr( e.then_expr.as_ref(), @@ -323,13 +386,21 @@ pub fn parse_physical_expr( "then_expr", input_schema, codec, + proto_converter, )?, )) }) .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr( + e.as_ref(), + ctx, + input_schema, + codec, + ) + }) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -339,6 +410,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, None, @@ -350,6 +422,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, convert_required!(e.arrow_type)?, )), @@ -362,7 +435,8 @@ pub fn parse_physical_expr( }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + let args = + parse_physical_exprs(&e.args, ctx, input_schema, codec, proto_converter)?; let config_options = Arc::clone(ctx.session_config().options()); @@ -391,6 +465,7 @@ pub fn parse_physical_expr( "expr", input_schema, codec, + proto_converter, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), @@ -398,11 +473,17 @@ pub fn parse_physical_expr( "pattern", input_schema, codec, + proto_converter, )?, )), ExprType::HashExpr(hash_expr) => { - let on_columns = - parse_physical_exprs(&hash_expr.on_columns, ctx, input_schema, codec)?; + let on_columns = parse_physical_exprs( + &hash_expr.on_columns, + ctx, + input_schema, + codec, + proto_converter, + )?; Arc::new(HashExpr::new( on_columns, SeededRandomState::with_seeds( @@ -418,9 +499,11 @@ pub fn parse_physical_expr( let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + .map(|e| { + proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec) + }) .collect::>()?; - (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) 
as _ + codec.try_decode_expr(extension.expr.as_slice(), &inputs)? as _ } }; @@ -433,8 +516,9 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec)) .transpose()? .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } @@ -444,11 +528,17 @@ pub fn parse_protobuf_hash_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = - parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + ctx, + input_schema, + codec, + proto_converter, + )?; Ok(Some(Partitioning::Hash( expr, @@ -464,6 +554,7 @@ pub fn parse_protobuf_partitioning( ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(protobuf::Partitioning { partition_method }) => match partition_method { @@ -478,6 +569,7 @@ pub fn parse_protobuf_partitioning( ctx, input_schema, codec, + proto_converter, ) } Some(protobuf::partitioning::PartitionMethod::Unknown(partition_count)) => { @@ -532,6 +624,7 @@ pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, ctx: &TaskContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, file_source: Arc, ) -> Result { let schema: Arc = parse_protobuf_file_scan_schema(proto)?; @@ -557,6 +650,7 @@ pub fn parse_protobuf_file_scan_config( ctx, &schema, codec, + proto_converter, )?; output_ordering.extend(LexOrdering::new(sort_exprs)); } @@ -567,7 +661,7 @@ pub fn 
parse_protobuf_file_scan_config( .projections .iter() .map(|proto_expr| { - let expr = parse_physical_expr( + let expr = proto_converter.proto_to_physical_expr( proto_expr.expr.as_ref().ok_or_else(|| { internal_datafusion_err!("ProjectionExpr missing expr field") })?, @@ -727,6 +821,17 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { protobuf::InsertOp::Overwrite => InsertOp::Overwrite, protobuf::InsertOp::Replace => InsertOp::Replace, }; + let file_output_mode = match conf.file_output_mode() { + protobuf::FileOutputMode::Automatic => { + datafusion_datasource::file_sink_config::FileOutputMode::Automatic + } + protobuf::FileOutputMode::SingleFile => { + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile + } + protobuf::FileOutputMode::Directory => { + datafusion_datasource::file_sink_config::FileOutputMode::Directory + } + }; Ok(Self { original_url: String::default(), object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, @@ -737,18 +842,20 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, file_extension: conf.file_extension.clone(), + file_output_mode, }) } } #[cfg(test)] mod tests { - use super::*; use chrono::{TimeZone, Utc}; use datafusion_datasource::PartitionedFile; use object_store::ObjectMeta; use object_store::path::Path; + use super::*; + #[test] fn partitioned_file_path_roundtrip_percent_encoded() { let path_str = "foo/foo%2Fbar/baz%252Fqux"; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 0666fc2979b38..bfba715b91249 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -15,33 +15,14 @@ // specific language governing permissions and limitations // under the License. 
+use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use self::from_proto::parse_protobuf_partitioning; -use self::to_proto::{serialize_partitioning, serialize_physical_expr}; -use crate::common::{byte_to_string, str_to_byte}; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_physical_window_expr, parse_protobuf_file_scan_config, parse_record_batches, - parse_table_schema_from_proto, -}; -use crate::physical_plan::to_proto::{ - serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, - serialize_physical_sort_exprs, serialize_physical_window_expr, - serialize_record_batches, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::{ - self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, - proto_error, window_agg_exec_node, -}; -use crate::{convert_required, into_required}; - use arrow::compute::SortOptions; -use arrow::datatypes::{IntervalMonthDayNanoType, SchemaRef}; +use arrow::datatypes::{IntervalMonthDayNanoType, Schema, SchemaRef}; use datafusion_catalog::memory::MemorySourceConfig; use datafusion_common::config::CsvOptions; use datafusion_common::{ @@ -53,6 +34,7 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::{DataSource, DataSourceExec}; +use datafusion_datasource_arrow::source::ArrowSource; #[cfg(feature = "avro")] use datafusion_datasource_avro::source::AvroSource; use datafusion_datasource_csv::file_format::CsvSink; @@ -68,12 +50,15 @@ use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use 
datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, }; -use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; -use datafusion_physical_plan::aggregates::AggregateMode; -use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, +}; use datafusion_physical_plan::analyze::AnalyzeExec; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::buffer::BufferExec; #[expect(deprecated)] use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -81,13 +66,12 @@ use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::explain::ExplainExec; use datafusion_physical_plan::expressions::PhysicalSortExpr; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder}; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, - SymmetricHashJoinExec, + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::memory::LazyMemoryExec; use 
datafusion_physical_plan::metrics::MetricType; @@ -100,12 +84,31 @@ use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; - -use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; -use datafusion_physical_plan::async_func::AsyncFuncExec; use prost::Message; use prost::bytes::BufMut; +use self::from_proto::parse_protobuf_partitioning; +use self::to_proto::serialize_partitioning; +use crate::common::{byte_to_string, str_to_byte}; +use crate::physical_plan::from_proto::{ + parse_physical_expr_with_converter, parse_physical_sort_expr, + parse_physical_sort_exprs, parse_physical_window_expr, + parse_protobuf_file_scan_config, parse_record_batches, parse_table_schema_from_proto, +}; +use crate::physical_plan::to_proto::{ + serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, + serialize_physical_expr_with_converter, serialize_physical_sort_exprs, + serialize_physical_window_expr, serialize_record_batches, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::{ + self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, + proto_error, window_agg_exec_node, +}; +use crate::{convert_required, into_required}; + pub mod from_proto; pub mod to_proto; @@ -132,8 +135,37 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_into_physical_plan( &self, ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + self.try_into_physical_plan_with_converter( + ctx, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } - extension_codec: &dyn PhysicalExtensionCodec, + fn try_from_physical_plan( + plan: Arc, + 
codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + Self::try_from_physical_plan_with_converter( + plan, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } +} + +impl protobuf::PhysicalPlanNode { + pub fn try_into_physical_plan_with_converter( + &self, + ctx: &TaskContext, + + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( @@ -142,125 +174,155 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, extension_codec) - } - PhysicalPlanType::Projection(projection) => { - self.try_into_projection_physical_plan(projection, ctx, extension_codec) + self.try_into_explain_physical_plan(explain, ctx, codec, proto_converter) } + PhysicalPlanType::Projection(projection) => self + .try_into_projection_physical_plan( + projection, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, extension_codec) + self.try_into_filter_physical_plan(filter, ctx, codec, proto_converter) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_csv_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::JsonScan(scan) => { - self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) - } - PhysicalPlanType::ParquetScan(scan) => { - self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_json_scan_physical_plan(scan, ctx, codec, proto_converter) } + PhysicalPlanType::ParquetScan(scan) => self + .try_into_parquet_scan_physical_plan(scan, ctx, codec, proto_converter), PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_avro_scan_physical_plan(scan, ctx, codec, 
proto_converter) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_memory_scan_physical_plan(scan, ctx, codec, proto_converter) + } + PhysicalPlanType::ArrowScan(scan) => { + self.try_into_arrow_scan_physical_plan(scan, ctx, codec, proto_converter) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, extension_codec) - } - PhysicalPlanType::Repartition(repart) => { - self.try_into_repartition_physical_plan(repart, ctx, extension_codec) - } - PhysicalPlanType::GlobalLimit(limit) => { - self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::LocalLimit(limit) => { - self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) - } - PhysicalPlanType::Window(window_agg) => { - self.try_into_window_physical_plan(window_agg, ctx, extension_codec) - } - PhysicalPlanType::Aggregate(hash_agg) => { - self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) - } - PhysicalPlanType::HashJoin(hashjoin) => { - self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + self.try_into_merge_physical_plan(merge, ctx, codec, proto_converter) } + PhysicalPlanType::Repartition(repart) => self + .try_into_repartition_physical_plan(repart, ctx, codec, proto_converter), + PhysicalPlanType::GlobalLimit(limit) => self + .try_into_global_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::LocalLimit(limit) => self + .try_into_local_limit_physical_plan(limit, ctx, codec, proto_converter), + PhysicalPlanType::Window(window_agg) => self.try_into_window_physical_plan( + window_agg, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Aggregate(hash_agg) => self + .try_into_aggregate_physical_plan(hash_agg, ctx, 
codec, proto_converter), + PhysicalPlanType::HashJoin(hashjoin) => self + .try_into_hash_join_physical_plan(hashjoin, ctx, codec, proto_converter), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, ctx, - extension_codec, + codec, + proto_converter, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, extension_codec) - } - PhysicalPlanType::Interleave(interleave) => { - self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) - } - PhysicalPlanType::CrossJoin(crossjoin) => { - self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) + self.try_into_union_physical_plan(union, ctx, codec, proto_converter) } - PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, extension_codec) - } - PhysicalPlanType::PlaceholderRow(placeholder) => self - .try_into_placeholder_row_physical_plan( - placeholder, + PhysicalPlanType::Interleave(interleave) => self + .try_into_interleave_physical_plan( + interleave, ctx, - extension_codec, + codec, + proto_converter, ), - PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, extension_codec) + PhysicalPlanType::CrossJoin(crossjoin) => self + .try_into_cross_join_physical_plan( + crossjoin, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Empty(empty) => { + self.try_into_empty_physical_plan(empty, ctx, codec, proto_converter) } - PhysicalPlanType::SortPreservingMerge(sort) => self - .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), - PhysicalPlanType::Extension(extension) => { - self.try_into_extension_physical_plan(extension, ctx, extension_codec) + PhysicalPlanType::PlaceholderRow(placeholder) => { + self.try_into_placeholder_row_physical_plan(placeholder, ctx, codec) } - PhysicalPlanType::NestedLoopJoin(join) => { - self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) + PhysicalPlanType::Sort(sort) => 
{ + self.try_into_sort_physical_plan(sort, ctx, codec, proto_converter) } + PhysicalPlanType::SortPreservingMerge(sort) => self + .try_into_sort_preserving_merge_physical_plan( + sort, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Extension(extension) => self + .try_into_extension_physical_plan(extension, ctx, codec, proto_converter), + PhysicalPlanType::NestedLoopJoin(join) => self + .try_into_nested_loop_join_physical_plan( + join, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + self.try_into_analyze_physical_plan(analyze, ctx, codec, proto_converter) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_json_sink_physical_plan(sink, ctx, codec, proto_converter) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_csv_sink_physical_plan(sink, ctx, codec, proto_converter) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => { - self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) - } + PhysicalPlanType::ParquetSink(sink) => self + .try_into_parquet_sink_physical_plan(sink, ctx, codec, proto_converter), PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) - } - PhysicalPlanType::Cooperative(cooperative) => { - self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + self.try_into_unnest_physical_plan(unnest, ctx, codec, proto_converter) } + PhysicalPlanType::Cooperative(cooperative) => self + .try_into_cooperative_physical_plan( + cooperative, + ctx, + codec, + proto_converter, + ), PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - 
self.try_into_sort_join(sort_join, ctx, extension_codec) + self.try_into_sort_join(sort_join, ctx, codec, proto_converter) } - PhysicalPlanType::AsyncFunc(async_func) => { - self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) + PhysicalPlanType::AsyncFunc(async_func) => self + .try_into_async_func_physical_plan( + async_func, + ctx, + codec, + proto_converter, + ), + PhysicalPlanType::Buffer(buffer) => { + self.try_into_buffer_physical_plan(buffer, ctx, codec, proto_converter) } } } - fn try_from_physical_plan( + pub fn try_from_physical_plan_with_converter( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result where Self: Sized, @@ -269,93 +331,96 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_explain_exec( - exec, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_explain_exec(exec, codec); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_projection_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_analyze_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_filter_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_global_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_local_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return 
protobuf::PhysicalPlanNode::try_from_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cross_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_aggregate_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(empty) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_empty_exec( - empty, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_empty_exec(empty, codec); } if let Some(empty) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_placeholder_row_exec( - empty, - extension_codec, + empty, codec, ); } @@ -363,14 +428,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(coalesce_batches) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_batches_exec( coalesce_batches, - extension_codec, + codec, + proto_converter, ); } if let Some(data_source_exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( data_source_exec, - extension_codec, + codec, + proto_converter, )? 
{ return Ok(node); @@ -379,67 +446,80 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_partitions_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_repartition_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_sort_exec(exec, extension_codec); + return protobuf::PhysicalPlanNode::try_from_sort_exec( + exec, + codec, + proto_converter, + ); } if let Some(union) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_union_exec( union, - extension_codec, + codec, + proto_converter, ); } if let Some(interleave) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_interleave_exec( interleave, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_preserving_merge_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_nested_loop_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_bounded_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( exec, - extension_codec, + codec, + proto_converter, )? 
{ return Ok(node); @@ -448,14 +528,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_unnest_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cooperative_exec( exec, - extension_codec, + codec, + proto_converter, ); } @@ -469,21 +551,31 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_async_func_exec( exec, - extension_codec, + codec, + proto_converter, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_buffer_exec( + exec, + codec, + proto_converter, ); } let mut buf: Vec = vec![]; - match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { + match codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { let inputs: Vec = plan_clone .children() .into_iter() .cloned() .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( i, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -507,7 +599,8 @@ impl protobuf::PhysicalPlanNode { explain: &protobuf::ExplainExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { Ok(Arc::new(ExplainExec::new( Arc::new(explain.schema.as_ref().unwrap().try_into()?), @@ -525,21 +618,22 @@ impl protobuf::PhysicalPlanNode { projection: &protobuf::ProjectionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, extension_codec)?; + 
into_physical_plan(&projection.input, ctx, codec, proto_converter)?; let exprs = projection .expr .iter() .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, name.to_string(), )) @@ -557,16 +651,22 @@ impl protobuf::PhysicalPlanNode { filter: &protobuf::FilterExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, extension_codec)?; + into_physical_plan(&filter.input, ctx, codec, proto_converter)?; let predicate = filter .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .transpose()? .ok_or_else(|| { @@ -588,8 +688,10 @@ impl protobuf::PhysicalPlanNode { None }; - let filter = - FilterExec::try_new(predicate, input)?.with_projection(projection)?; + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection)? + .with_batch_size(filter.batch_size as usize) + .build()?; match filter_selectivity { Ok(filter_selectivity) => Ok(Arc::new( filter.with_default_selectivity(filter_selectivity)?, @@ -605,7 +707,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::CsvScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let escape = if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape)) = @@ -646,7 +749,8 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, source, )?) 
.with_file_compression_type(FileCompressionType::UNCOMPRESSED) @@ -659,25 +763,49 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::JsonScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().unwrap(); let table_schema = parse_table_schema_from_proto(base_conf)?; let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + codec, + proto_converter, Arc::new(JsonSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(scan_conf)) } + fn try_into_arrow_scan_physical_plan( + &self, + scan: &protobuf::ArrowScanExecNode, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let base_conf = scan.base_conf.as_ref().ok_or_else(|| { + internal_datafusion_err!("base_conf in ArrowScanExecNode is missing.") + })?; + let table_schema = parse_table_schema_from_proto(base_conf)?; + let scan_conf = parse_protobuf_file_scan_config( + base_conf, + ctx, + codec, + proto_converter, + Arc::new(ArrowSource::new_file_source(table_schema)), + )?; + Ok(DataSourceExec::from_data_source(scan_conf)) + } + #[cfg_attr(not(feature = "parquet"), expect(unused_variables))] fn try_into_parquet_scan_physical_plan( &self, scan: &protobuf::ParquetScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { @@ -694,7 +822,7 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|&i| schema.field(i as usize).clone()) .collect(); - Arc::new(arrow::datatypes::Schema::new(projected_fields)) + Arc::new(Schema::new(projected_fields)) } else { schema }; @@ -703,11 +831,11 @@ impl protobuf::PhysicalPlanNode { .predicate .as_ref() .map(|expr| { - parse_physical_expr( + 
proto_converter.proto_to_physical_expr( expr, ctx, predicate_schema.as_ref(), - extension_codec, + codec, ) }) .transpose()?; @@ -729,7 +857,8 @@ impl protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + codec, + proto_converter, Arc::new(source), )?; Ok(DataSourceExec::from_data_source(base_config)) @@ -745,7 +874,8 @@ impl protobuf::PhysicalPlanNode { &self, scan: &protobuf::AvroScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "avro")] { @@ -754,7 +884,8 @@ impl protobuf::PhysicalPlanNode { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + codec, + proto_converter, Arc::new(AvroSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(conf)) @@ -769,7 +900,8 @@ impl protobuf::PhysicalPlanNode { scan: &protobuf::MemoryScanExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let partitions = scan .partitions @@ -799,7 +931,8 @@ impl protobuf::PhysicalPlanNode { &ordering.physical_sort_expr_nodes, ctx, &schema, - extension_codec, + codec, + proto_converter, )?; sort_information.extend(LexOrdering::new(sort_exprs)); } @@ -818,10 +951,11 @@ impl protobuf::PhysicalPlanNode { coalesce_batches: &protobuf::CoalesceBatchesExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + into_physical_plan(&coalesce_batches.input, ctx, codec, proto_converter)?; Ok(Arc::new( #[expect(deprecated)] CoalesceBatchesExec::new(input, 
coalesce_batches.target_batch_size as usize) @@ -834,10 +968,11 @@ impl protobuf::PhysicalPlanNode { merge: &protobuf::CoalescePartitionsExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, extension_codec)?; + into_physical_plan(&merge.input, ctx, codec, proto_converter)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -849,15 +984,17 @@ impl protobuf::PhysicalPlanNode { repart: &protobuf::RepartitionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, extension_codec)?; + into_physical_plan(&repart.input, ctx, codec, proto_converter)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), ctx, input.schema().as_ref(), - extension_codec, + codec, + proto_converter, )?; Ok(Arc::new(RepartitionExec::try_new( input, @@ -870,10 +1007,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::GlobalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -891,10 +1029,11 @@ impl protobuf::PhysicalPlanNode { limit: &protobuf::LocalLimitExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, 
extension_codec)?; + into_physical_plan(&limit.input, ctx, codec, proto_converter)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } @@ -903,10 +1042,11 @@ impl protobuf::PhysicalPlanNode { window_agg: &protobuf::WindowAggExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, extension_codec)?; + into_physical_plan(&window_agg.input, ctx, codec, proto_converter)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -917,7 +1057,8 @@ impl protobuf::PhysicalPlanNode { window_expr, ctx, input_schema.as_ref(), - extension_codec, + codec, + proto_converter, ) }) .collect::, _>>()?; @@ -926,7 +1067,12 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr( + expr, + ctx, + input.schema().as_ref(), + codec, + ) }) .collect::>>>()?; @@ -961,10 +1107,11 @@ impl protobuf::PhysicalPlanNode { hash_agg: &protobuf::AggregateExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + into_physical_plan(&hash_agg.input, ctx, codec, proto_converter)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -979,6 +1126,7 @@ impl protobuf::PhysicalPlanNode { protobuf::AggregateMode::SinglePartitioned => { AggregateMode::SinglePartitioned } + protobuf::AggregateMode::PartialReduce => AggregateMode::PartialReduce, }; let num_expr = hash_agg.group_expr.len(); @@ -988,7 +1136,8 @@ impl 
protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -998,7 +1147,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -1027,7 +1177,12 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - parse_physical_expr(e, ctx, &physical_schema, extension_codec) + proto_converter.proto_to_physical_expr( + e, + ctx, + &physical_schema, + codec, + ) }) .transpose() }) @@ -1048,11 +1203,11 @@ impl protobuf::PhysicalPlanNode { .expr .iter() .map(|e| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( e, ctx, &physical_schema, - extension_codec, + codec, ) }) .collect::>>()?; @@ -1064,7 +1219,8 @@ impl protobuf::PhysicalPlanNode { e, ctx, &physical_schema, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -1074,11 +1230,11 @@ impl protobuf::PhysicalPlanNode { .map(|func| match func { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { - Some(buf) => extension_codec - .try_decode_udaf(udaf_name, buf)?, + Some(buf) => { + codec.try_decode_udaf(udaf_name, buf)? 
+ } None => ctx.udaf(udaf_name).or_else(|_| { - extension_codec - .try_decode_udaf(udaf_name, &[]) + codec.try_decode_udaf(udaf_name, &[]) })?, }; @@ -1105,11 +1261,6 @@ impl protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let limit = hash_agg - .limit - .as_ref() - .map(|lit_value| lit_value.limit as usize); - let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), @@ -1119,7 +1270,16 @@ impl protobuf::PhysicalPlanNode { physical_schema, )?; - let agg = agg.with_limit(limit); + let agg = if let Some(limit_proto) = &hash_agg.limit { + let limit = limit_proto.limit as usize; + let limit_options = match limit_proto.descending { + Some(descending) => LimitOptions::new_with_order(limit, descending), + None => LimitOptions::new(limit), + }; + agg.with_limit_options(Some(limit_options)) + } else { + agg + }; Ok(Arc::new(agg)) } @@ -1129,29 +1289,30 @@ impl protobuf::PhysicalPlanNode { hashjoin: &protobuf::HashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + into_physical_plan(&hashjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + into_physical_plan(&hashjoin.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ 
-1180,12 +1341,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1239,6 +1400,7 @@ impl protobuf::PhysicalPlanNode { projection, partition_mode, null_equality.into(), + hashjoin.null_aware, )?)) } @@ -1247,27 +1409,28 @@ impl protobuf::PhysicalPlanNode { sym_join: &protobuf::SymmetricHashJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; - let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left = into_physical_plan(&sym_join.left, ctx, codec, proto_converter)?; + let right = into_physical_plan(&sym_join.right, ctx, codec, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1296,12 +1459,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? 
.try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1327,7 +1490,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.left_sort_exprs, ctx, &left_schema, - extension_codec, + codec, + proto_converter, )?; let left_sort_exprs = LexOrdering::new(left_sort_exprs); @@ -1335,7 +1499,8 @@ impl protobuf::PhysicalPlanNode { &sym_join.right_sort_exprs, ctx, &right_schema, - extension_codec, + codec, + proto_converter, )?; let right_sort_exprs = LexOrdering::new(right_sort_exprs); @@ -1375,11 +1540,12 @@ impl protobuf::PhysicalPlanNode { union: &protobuf::UnionExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } UnionExec::try_new(inputs) } @@ -1389,11 +1555,12 @@ impl protobuf::PhysicalPlanNode { interleave: &protobuf::InterleaveExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1403,12 +1570,13 @@ impl protobuf::PhysicalPlanNode { crossjoin: &protobuf::CrossJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn 
PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + into_physical_plan(&crossjoin.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + into_physical_plan(&crossjoin.right, ctx, codec, proto_converter)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } @@ -1417,7 +1585,8 @@ impl protobuf::PhysicalPlanNode { empty: &protobuf::EmptyExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let schema = Arc::new(convert_required!(empty.schema)?); Ok(Arc::new(EmptyExec::new(schema))) @@ -1428,7 +1597,7 @@ impl protobuf::PhysicalPlanNode { placeholder: &protobuf::PlaceholderRowExecNode, _ctx: &TaskContext, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result> { let schema = Arc::new(convert_required!(placeholder.schema)?); Ok(Arc::new(PlaceholderRowExec::new(schema))) @@ -1439,9 +1608,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1462,7 +1632,7 @@ impl protobuf::PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + expr: proto_converter.proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1491,9 +1661,10 @@ impl protobuf::PhysicalPlanNode { sort: &protobuf::SortPreservingMergeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; let exprs = sort .expr .iter() @@ -1514,11 +1685,11 @@ impl protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr( + expr: proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?, options: SortOptions { descending: !sort_expr.asc, @@ -1544,16 +1715,16 @@ impl protobuf::PhysicalPlanNode { extension: &protobuf::PhysicalExtensionNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| proto_converter.proto_to_execution_plan(ctx, codec, i)) .collect::>()?; - let extension_node = - extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + let extension_node = codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; Ok(extension_node) } @@ -1563,12 +1734,13 @@ impl protobuf::PhysicalPlanNode { join: &protobuf::NestedLoopJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc 
= - into_physical_plan(&join.left, ctx, extension_codec)?; + into_physical_plan(&join.left, ctx, codec, proto_converter)?; let right: Arc = - into_physical_plan(&join.right, ctx, extension_codec)?; + into_physical_plan(&join.right, ctx, codec, proto_converter)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1585,12 +1757,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f.column_indices .iter() @@ -1637,10 +1809,11 @@ impl protobuf::PhysicalPlanNode { analyze: &protobuf::AnalyzeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, extension_codec)?; + into_physical_plan(&analyze.input, ctx, codec, proto_converter)?; Ok(Arc::new(AnalyzeExec::new( analyze.verbose, analyze.show_statistics, @@ -1655,9 +1828,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::JsonSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: JsonSink = sink .sink @@ -1673,7 +1847,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { 
LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1693,9 +1868,10 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::CsvSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: CsvSink = sink .sink @@ -1711,7 +1887,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1732,11 +1909,12 @@ impl protobuf::PhysicalPlanNode { sink: &protobuf::ParquetSinkExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; let data_sink: ParquetSink = sink .sink @@ -1752,7 +1930,8 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1775,9 +1954,10 @@ impl protobuf::PhysicalPlanNode { unnest: &protobuf::UnnestExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + let input = into_physical_plan(&unnest.input, ctx, codec, proto_converter)?; Ok(Arc::new(UnnestExec::new( input, @@ -1806,11 +1986,12 @@ impl protobuf::PhysicalPlanNode { 
sort_join: &SortMergeJoinExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left = into_physical_plan(&sort_join.left, ctx, codec, proto_converter)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right = into_physical_plan(&sort_join.right, ctx, codec, proto_converter)?; let right_schema = right.schema(); let filter = sort_join @@ -1823,13 +2004,13 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, ctx, &schema, - extension_codec, + codec, )?; let column_indices = f .column_indices @@ -1886,17 +2067,17 @@ impl protobuf::PhysicalPlanNode { .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), ctx, left_schema.as_ref(), - extension_codec, + codec, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), ctx, right_schema.as_ref(), - extension_codec, + codec, )?; Ok((left, right)) }) @@ -1983,9 +2164,10 @@ impl protobuf::PhysicalPlanNode { field_stream: &protobuf::CooperativeExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; + let input = into_physical_plan(&field_stream.input, ctx, codec, proto_converter)?; Ok(Arc::new(CooperativeExec::new(input))) } @@ -1993,10 +2175,11 @@ impl 
protobuf::PhysicalPlanNode { &self, async_func: &protobuf::AsyncFuncExecNode, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&async_func.input, ctx, extension_codec)?; + into_physical_plan(&async_func.input, ctx, codec, proto_converter)?; if async_func.async_exprs.len() != async_func.async_expr_names.len() { return internal_err!( @@ -2009,11 +2192,11 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(async_func.async_expr_names.iter()) .map(|(expr, name)| { - let physical_expr = parse_physical_expr( + let physical_expr = proto_converter.proto_to_physical_expr( expr, ctx, input.schema().as_ref(), - extension_codec, + codec, )?; Ok(Arc::new(AsyncFuncExpr::try_new( @@ -2027,9 +2210,22 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) } + fn try_into_buffer_physical_plan( + &self, + buffer: &protobuf::BufferExecNode, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let input: Arc = + into_physical_plan(&buffer.input, ctx, extension_codec, proto_converter)?; + + Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) + } + fn try_from_explain_exec( exec: &ExplainExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( @@ -2048,16 +2244,20 @@ impl protobuf::PhysicalPlanNode { fn try_from_projection_exec( exec: &ProjectionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( 
exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() - .map(|proj_expr| serialize_physical_expr(&proj_expr.expr, extension_codec)) + .map(|proj_expr| { + proto_converter.physical_expr_to_proto(&proj_expr.expr, codec) + }) .collect::>>()?; let expr_name = exec .expr() @@ -2077,11 +2277,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_analyze_exec( exec: &AnalyzeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( @@ -2097,24 +2299,27 @@ impl protobuf::PhysicalPlanNode { fn try_from_filter_exec( exec: &FilterExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(serialize_physical_expr( - exec.predicate(), - extension_codec, - )?), + expr: Some( + proto_converter + .physical_expr_to_proto(exec.predicate(), codec)?, + ), default_filter_selectivity: exec.default_selectivity() as u32, projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + batch_size: exec.batch_size() as u32, }, ))), }) @@ -2122,11 +2327,13 @@ impl protobuf::PhysicalPlanNode { fn 
try_from_global_limit_exec( limit: &GlobalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -2145,11 +2352,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_local_limit_exec( limit: &LocalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( @@ -2163,22 +2372,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_hash_join_exec( exec: &HashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on: Vec = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, 
codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2192,7 +2404,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2232,6 +2444,7 @@ impl protobuf::PhysicalPlanNode { projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + null_aware: exec.null_aware, }, ))), }) @@ -2239,22 +2452,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2268,7 +2484,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() 
.iter() @@ -2305,10 +2521,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2325,10 +2541,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2357,22 +2573,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_merge_join_exec( exec: &SortMergeJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -2386,7 +2605,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), 
extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2426,7 +2645,7 @@ impl protobuf::PhysicalPlanNode { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new( - protobuf::SortMergeJoinExecNode { + SortMergeJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), on, @@ -2441,15 +2660,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_cross_join_exec( exec: &CrossJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( @@ -2463,7 +2685,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_aggregate_exec( exec: &AggregateExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let groups: Vec = exec .group_expr() @@ -2483,13 +2706,15 @@ impl protobuf::PhysicalPlanNode { let filter = exec .filter_expr() .iter() - .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) + .map(|expr| serialize_maybe_filter(expr.to_owned(), codec, proto_converter)) .collect::>>()?; let agg = exec .aggr_expr() .iter() - .map(|expr| serialize_physical_aggr_expr(expr.to_owned(), extension_codec)) + .map(|expr| { + serialize_physical_aggr_expr(expr.to_owned(), codec, proto_converter) + }) 
.collect::>>()?; let agg_names = exec @@ -2506,29 +2731,32 @@ impl protobuf::PhysicalPlanNode { AggregateMode::SinglePartitioned => { protobuf::AggregateMode::SinglePartitioned } + AggregateMode::PartialReduce => protobuf::AggregateMode::PartialReduce, }; let input_schema = exec.input_schema(); - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let null_expr = exec .group_expr() .null_expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; - let limit = exec.limit().map(|value| protobuf::AggLimit { - limit: value as u64, + let limit = exec.limit_options().map(|config| protobuf::AggLimit { + limit: config.limit() as u64, + descending: config.descending(), }); Ok(protobuf::PhysicalPlanNode { @@ -2553,7 +2781,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_empty_exec( empty: &EmptyExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2565,7 +2793,7 @@ impl protobuf::PhysicalPlanNode { fn try_from_placeholder_row_exec( empty: &PlaceholderRowExec, - _extension_codec: &dyn PhysicalExtensionCodec, + _codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { @@ -2580,11 +2808,13 @@ impl protobuf::PhysicalPlanNode { #[expect(deprecated)] fn try_from_coalesce_batches_exec( coalesce_batches: &CoalesceBatchesExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn 
PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( coalesce_batches.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( @@ -2599,7 +2829,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_source_exec( data_source_exec: &DataSourceExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let data_source = data_source_exec.data_source(); if let Some(maybe_csv) = data_source.as_any().downcast_ref::() { @@ -2610,7 +2841,8 @@ impl protobuf::PhysicalPlanNode { protobuf::CsvScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_csv, - extension_codec, + codec, + proto_converter, )?), has_header: csv_config.has_header(), delimiter: byte_to_string( @@ -2651,7 +2883,25 @@ impl protobuf::PhysicalPlanNode { protobuf::JsonScanExecNode { base_conf: Some(serialize_file_scan_config( scan_conf, - extension_codec, + codec, + proto_converter, + )?), + }, + )), + })); + } + } + + if let Some(scan_conf) = data_source.as_any().downcast_ref::() { + let source = scan_conf.file_source(); + if let Some(_arrow_source) = source.as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ArrowScan( + protobuf::ArrowScanExecNode { + base_conf: Some(serialize_file_scan_config( + scan_conf, + codec, + proto_converter, )?), }, )), @@ -2665,14 +2915,15 @@ impl protobuf::PhysicalPlanNode { { let predicate = conf .filter() - .map(|pred| serialize_physical_expr(&pred, extension_codec)) + .map(|pred| proto_converter.physical_expr_to_proto(&pred, codec)) .transpose()?; return Ok(Some(protobuf::PhysicalPlanNode { 
physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_parquet, - extension_codec, + codec, + proto_converter, )?), predicate, parquet_options: Some(conf.table_parquet_options().try_into()?), @@ -2690,7 +2941,8 @@ impl protobuf::PhysicalPlanNode { protobuf::AvroScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_avro, - extension_codec, + codec, + proto_converter, )?), }, )), @@ -2723,7 +2975,8 @@ impl protobuf::PhysicalPlanNode { .map(|ordering| { let sort_exprs = serialize_physical_sort_exprs( ordering.to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok::<_, DataFusionError>(protobuf::PhysicalSortExprNodeCollection { physical_sort_expr_nodes: sort_exprs, @@ -2750,11 +3003,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_coalesce_partitions_exec( exec: &CoalescePartitionsExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( @@ -2768,15 +3023,17 @@ impl protobuf::PhysicalPlanNode { fn try_from_repartition_exec( exec: &RepartitionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let pb_partitioning = - serialize_partitioning(exec.partitioning(), extension_codec)?; + serialize_partitioning(exec.partitioning(), 
codec, proto_converter)?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( @@ -2790,25 +3047,23 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_exec( exec: &SortExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + let input = proto_converter.execution_plan_to_proto(exec.input(), codec)?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2830,14 +3085,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_union_exec( union: &UnionExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: Vec = vec![]; for input in union.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union(protobuf::UnionExecNode { @@ -2848,14 +3107,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_interleave_exec( interleave: &InterleaveExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let mut inputs: 
Vec = vec![]; for input in interleave.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Interleave( @@ -2866,25 +3129,27 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_preserving_merge_exec( exec: &SortPreservingMergeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2902,15 +3167,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_nested_loop_join_exec( exec: &NestedLoopJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = 
protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); @@ -2919,7 +3187,7 @@ impl protobuf::PhysicalPlanNode { .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2947,7 +3215,7 @@ impl protobuf::PhysicalPlanNode { right: Some(Box::new(right)), join_type: join_type.into(), filter, - projection: exec.projection().map_or_else(Vec::new, |v| { + projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), }, @@ -2957,23 +3225,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_window_agg_exec( exec: &WindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; Ok(protobuf::PhysicalPlanNode { @@ -2990,23 +3260,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_bounded_window_agg_exec( exec: &BoundedWindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = 
protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -3039,12 +3311,14 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_sink_exec( exec: &DataSinkExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let sort_order = match exec.sort_order() { Some(requirements) => { @@ -3053,10 +3327,10 @@ impl protobuf::PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; @@ -3116,11 +3390,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_unnest_exec( exec: &UnnestExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = 
protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3149,11 +3425,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_cooperative_exec( exec: &CooperativeExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3282,18 +3560,21 @@ impl protobuf::PhysicalPlanNode { fn try_from_async_func_exec( exec: &AsyncFuncExec, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( Arc::clone(exec.input()), - extension_codec, + codec, + proto_converter, )?; let mut async_exprs = vec![]; let mut async_expr_names = vec![]; for async_expr in exec.async_exprs() { - async_exprs.push(serialize_physical_expr(&async_expr.func, extension_codec)?); + async_exprs + .push(proto_converter.physical_expr_to_proto(&async_expr.func, codec)?); async_expr_names.push(async_expr.name.clone()) } @@ -3307,6 +3588,27 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_buffer_exec( + exec: &BufferExec, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + extension_codec, + proto_converter, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: 
Some(PhysicalPlanType::Buffer(Box::new( + protobuf::BufferExecNode { + input: Some(Box::new(input)), + capacity: exec.capacity() as u64, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3323,12 +3625,12 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { &self, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result>; fn try_from_physical_plan( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result where Self: Sized; @@ -3409,6 +3711,38 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { } } +/// Controls the conversion of physical plans and expressions to and from their +/// Protobuf variants. Using this trait, users can perform optimizations on the +/// conversion process or collect performance metrics. +pub trait PhysicalProtoConverterExtension { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result>; + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result>; + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; +} + /// DataEncoderTuple captures the position of the encoder /// in the codec list that was used to encode the data and actual encoded data #[derive(Clone, PartialEq, prost::Message)] @@ -3422,6 +3756,266 @@ struct DataEncoderTuple { pub blob: Vec, } +pub struct DefaultPhysicalProtoConverter; +impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> 
{ + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + // Default implementation calls the free function + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +/// Internal serializer that adds expr_id to expressions. +/// Created fresh for each serialization operation. +struct DeduplicatingSerializer { + /// Random salt combined with pointer addresses and process ID to create globally unique expr_ids. 
+ session_id: u64, +} + +impl DeduplicatingSerializer { + fn new() -> Self { + Self { + session_id: rand::random(), + } + } +} + +impl PhysicalProtoConverterExtension for DeduplicatingSerializer { + fn proto_to_execution_plan( + &self, + _ctx: &TaskContext, + _codec: &dyn PhysicalExtensionCodec, + _proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + internal_err!("DeduplicatingSerializer cannot deserialize execution plans") + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + _proto: &protobuf::PhysicalExprNode, + _ctx: &TaskContext, + _input_schema: &Schema, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + internal_err!("DeduplicatingSerializer cannot deserialize physical expressions") + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let mut proto = serialize_physical_expr_with_converter(expr, codec, self)?; + + // Hash session_id, pointer address, and process ID together to create expr_id. + // - session_id: random per serializer, prevents collisions when merging serializations + // - ptr: unique address per Arc within a process + // - pid: prevents collisions if serializer is shared across processes + let mut hasher = DefaultHasher::new(); + self.session_id.hash(&mut hasher); + (Arc::as_ptr(expr) as *const () as u64).hash(&mut hasher); + std::process::id().hash(&mut hasher); + proto.expr_id = Some(hasher.finish()); + + Ok(proto) + } +} + +/// Internal deserializer that caches expressions by expr_id. +/// Created fresh for each deserialization operation. +#[derive(Default)] +struct DeduplicatingDeserializer { + /// Cache mapping expr_id to deserialized expressions. 
+ cache: RefCell>>, +} + +impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + _plan: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + internal_err!("DeduplicatingDeserializer cannot serialize execution plans") + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + if let Some(expr_id) = proto.expr_id { + // Check cache first + if let Some(cached) = self.cache.borrow().get(&expr_id) { + return Ok(Arc::clone(cached)); + } + // Deserialize and cache + let expr = parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + self, + )?; + self.cache.borrow_mut().insert(expr_id, Arc::clone(&expr)); + Ok(expr) + } else { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + } + + fn physical_expr_to_proto( + &self, + _expr: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { + internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") + } +} + +/// A proto converter that adds expression deduplication during serialization +/// and deserialization. +/// +/// During serialization, each expression's Arc pointer address is XORed with a +/// random session_id to create a salted `expr_id`. This prevents cross-process +/// collisions when serialized plans are merged. +/// +/// During deserialization, expressions with the same `expr_id` share the same +/// Arc, reducing memory usage for plans with duplicate expressions (e.g., large +/// IN lists) and supporting correctly linking [`DynamicFilterPhysicalExpr`] instances. 
+/// +/// This converter is stateless - it creates internal serializers/deserializers +/// on demand for each operation. +/// +/// [`DynamicFilterPhysicalExpr`]: https://docs.rs/datafusion-physical-expr/latest/datafusion_physical_expr/expressions/struct.DynamicFilterPhysicalExpr.html +#[derive(Debug, Default, Clone, Copy)] +pub struct DeduplicatingProtoConverter {} + +impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + let deserializer = DeduplicatingDeserializer::default(); + proto.try_into_physical_plan_with_converter(ctx, codec, &deserializer) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + let serializer = DeduplicatingSerializer::new(); + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + &serializer, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + let deserializer = DeduplicatingDeserializer::default(); + deserializer.proto_to_physical_expr(proto, ctx, input_schema, codec) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let serializer = DeduplicatingSerializer::new(); + serializer.physical_expr_to_proto(expr, codec) + } +} + /// A PhysicalExtensionCodec that tries one of multiple inner codecs /// until one works #[derive(Debug)] @@ -3524,10 +4118,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> 
Result> { if let Some(field) = node { - field.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(ctx, codec, field) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 9558effb8a2a6..a38e59acdab26 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,7 @@ use datafusion_common::{ DataFusionError, Result, internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::file_sink_config::FileSink; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_csv::file_format::CsvSink; use datafusion_datasource_json::file_format::JsonSink; @@ -36,36 +35,43 @@ use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_plan::expressions::LikeExpr; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr}; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + 
PhysicalProtoConverterExtension, +}; use crate::protobuf::{ self, PhysicalSortExprNode, PhysicalSortExprNodeCollection, physical_aggregate_expr_node, physical_window_expr_node, }; -use super::PhysicalExtensionCodec; - #[expect(clippy::needless_pass_by_value)] pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let order_bys = - serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; + let expressions = + serialize_physical_exprs(&aggr_expr.expressions(), codec, proto_converter)?; + let order_bys = serialize_physical_sort_exprs( + aggr_expr.order_bys().iter().cloned(), + codec, + proto_converter, + )?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -100,6 +106,7 @@ fn serialize_physical_window_aggr_expr( pub fn serialize_physical_window_expr( window_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let expr = window_expr.as_any(); let args = window_expr.expressions().to_vec(); @@ -155,9 +162,14 @@ pub fn serialize_physical_window_expr( return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - let args = serialize_physical_exprs(&args, codec)?; - let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; - let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; + let args = serialize_physical_exprs(&args, codec, proto_converter)?; + let partition_by = + 
serialize_physical_exprs(window_expr.partition_by(), codec, proto_converter)?; + let order_by = serialize_physical_sort_exprs( + window_expr.order_by().to_vec(), + codec, + proto_converter, + )?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() .try_into() @@ -179,22 +191,24 @@ pub fn serialize_physical_window_expr( pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator, { sort_exprs .into_iter() - .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec, proto_converter)) .collect() } pub fn serialize_physical_sort_expr( sort_expr: PhysicalSortExpr, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; - let expr = serialize_physical_expr(&expr, codec)?; + let expr = proto_converter.physical_expr_to_proto(&expr, codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !options.descending, @@ -205,13 +219,14 @@ pub fn serialize_physical_sort_expr( pub fn serialize_physical_exprs<'a, I>( values: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator>, { values .into_iter() - .map(|value| serialize_physical_expr(value, codec)) + .map(|value| proto_converter.physical_expr_to_proto(value, codec)) .collect() } @@ -222,6 +237,24 @@ where pub fn serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, +) -> Result { + serialize_physical_expr_with_converter( + value, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Serialize a `PhysicalExpr` to default protobuf representation. 
+/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]). +/// A [`PhysicalProtoConverterExtension`] can be provided to handle the +/// conversion process (see [`PhysicalProtoConverterExtension::physical_expr_to_proto`]). +pub fn serialize_physical_expr_with_converter( + value: &Arc, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { // Snapshot the expr in case it has dynamic predicate state so // it can be serialized @@ -248,12 +281,14 @@ pub fn serialize_physical_expr( )), }; return Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal(value)), }); } if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Column( protobuf::PhysicalColumn { name: expr.name().to_string(), @@ -263,6 +298,7 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( protobuf::UnknownColumn { name: expr.name().to_string(), @@ -271,18 +307,24 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_physical_expr(expr.left(), codec)?)), - r: Some(Box::new(serialize_physical_expr(expr.right(), codec)?)), + l: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.left(), codec)?, + )), + r: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.right(), codec)?, + )), op: format!("{:?}", expr.op()), }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = 
expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some( protobuf::physical_expr_node::ExprType::Case( Box::new( @@ -290,14 +332,21 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_physical_expr(exp, codec).map(Box::new) + proto_converter + .physical_expr_to_proto(exp, codec) + .map(Box::new) }) .transpose()?, when_then_expr: expr .when_then_expr() .iter() .map(|(when_expr, then_expr)| { - serialize_when_then_expr(when_expr, then_expr, codec) + serialize_when_then_expr( + when_expr, + then_expr, + codec, + proto_converter, + ) }) .collect::, @@ -305,7 +354,11 @@ pub fn serialize_physical_expr( >>()?, else_expr: expr .else_expr() - .map(|a| serialize_physical_expr(a, codec).map(Box::new)) + .map(|a| { + proto_converter + .physical_expr_to_proto(a, codec) + .map(Box::new) + }) .transpose()?, }, ), @@ -314,66 +367,88 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + 
proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - list: serialize_physical_exprs(expr.list(), codec)?, + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + list: serialize_physical_exprs(expr.list(), codec, proto_converter)?, negated: expr.negated(), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.arg(), codec)?, + )), }, ))), }) } else if let Some(lit) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: Some(cast.cast_type().try_into()?), }, ))), }) } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(cast.expr(), codec)?, + )), arrow_type: 
Some(cast.cast_type().try_into()?), }, ))), @@ -382,10 +457,11 @@ pub fn serialize_physical_expr( let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), - args: serialize_physical_exprs(expr.args(), codec)?, + args: serialize_physical_exprs(expr.args(), codec, proto_converter)?, fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), nullable: expr.nullable(), @@ -398,24 +474,31 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( protobuf::PhysicalLikeExprNode { negated: expr.negated(), case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - pattern: Some(Box::new(serialize_physical_expr( - expr.pattern(), - codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.expr(), codec)?, + )), + pattern: Some(Box::new( + proto_converter.physical_expr_to_proto(expr.pattern(), codec)?, + )), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { let (s0, s1, s2, s3) = expr.seeds(); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( protobuf::PhysicalHashExprNode { - on_columns: serialize_physical_exprs(expr.on_columns(), codec)?, + on_columns: serialize_physical_exprs( + expr.on_columns(), + codec, + proto_converter, + )?, seed0: s0, seed1: s1, seed2: s2, @@ -431,9 +514,10 @@ pub fn serialize_physical_expr( let inputs: Vec = value .children() .into_iter() - .map(|e| serialize_physical_expr(e, codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>()?; Ok(protobuf::PhysicalExprNode { + expr_id: None, 
expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), @@ -449,6 +533,7 @@ pub fn serialize_physical_expr( pub fn serialize_partitioning( partitioning: &Partitioning, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let serialized_partitioning = match partitioning { Partitioning::RoundRobinBatch(partition_count) => protobuf::Partitioning { @@ -457,7 +542,8 @@ pub fn serialize_partitioning( )), }, Partitioning::Hash(exprs, partition_count) => { - let serialized_exprs = serialize_physical_exprs(exprs, codec)?; + let serialized_exprs = + serialize_physical_exprs(exprs, codec, proto_converter)?; protobuf::Partitioning { partition_method: Some(protobuf::partitioning::PartitionMethod::Hash( protobuf::PhysicalHashRepartition { @@ -480,10 +566,11 @@ fn serialize_when_then_expr( when_expr: &Arc, then_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr, codec)?), - then_expr: Some(serialize_physical_expr(then_expr, codec)?), + when_expr: Some(proto_converter.physical_expr_to_proto(when_expr, codec)?), + then_expr: Some(proto_converter.physical_expr_to_proto(then_expr, codec)?), }) } @@ -539,6 +626,7 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let file_groups = conf .file_groups @@ -548,7 +636,8 @@ pub fn serialize_file_scan_config( let mut output_orderings = vec![]; for order in &conf.output_ordering { - let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?; + let ordering = + serialize_physical_sort_exprs(order.to_vec(), codec, proto_converter)?; output_orderings.push(ordering) } @@ -563,8 +652,7 @@ 
pub fn serialize_file_scan_config( fields.extend(conf.table_partition_cols().iter().cloned()); let schema = Arc::new( - arrow::datatypes::Schema::new(fields.clone()) - .with_metadata(conf.file_schema().metadata.clone()), + Schema::new(fields.clone()).with_metadata(conf.file_schema().metadata.clone()), ); let projection_exprs = conf @@ -579,7 +667,10 @@ pub fn serialize_file_scan_config( .map(|expr| { Ok(protobuf::ProjectionExpr { alias: expr.alias.to_string(), - expr: Some(serialize_physical_expr(&expr.expr, codec)?), + expr: Some( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + ), }) }) .collect::>>()?, @@ -614,11 +705,12 @@ pub fn serialize_file_scan_config( pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(&expr, codec)?), + expr: Some(proto_converter.physical_expr_to_proto(&expr, codec)?), }), } } @@ -695,6 +787,17 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { }) }) .collect::>>()?; + let file_output_mode = match conf.file_output_mode { + datafusion_datasource::file_sink_config::FileOutputMode::Automatic => { + protobuf::FileOutputMode::Automatic + } + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile => { + protobuf::FileOutputMode::SingleFile + } + datafusion_datasource::file_sink_config::FileOutputMode::Directory => { + protobuf::FileOutputMode::Directory + } + }; Ok(Self { object_store_url: conf.object_store_url.to_string(), file_groups, @@ -704,6 +807,7 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { keep_partition_by_columns: conf.keep_partition_by_columns, insert_op: conf.insert_op as i32, file_extension: conf.file_extension.to_string(), + file_output_mode: file_output_mode.into(), }) } } diff --git 
a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bcfda648b53e5..9407cbf9a0749 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -28,7 +28,7 @@ use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; -use datafusion::execution::options::ArrowReadOptions; +use datafusion::execution::options::{ArrowReadOptions, JsonReadOptions}; use datafusion::optimizer::Optimizer; use datafusion::optimizer::optimize_unions::OptimizeUnions; use datafusion_common::parquet_config::DFParquetWriterVersion; @@ -413,6 +413,7 @@ async fn roundtrip_logical_plan_dml() -> Result<()> { "DELETE FROM T1", "UPDATE T1 SET a = 1", "CREATE TABLE T2 AS SELECT * FROM T1", + "TRUNCATE TABLE T1", ]; for query in queries { let plan = ctx.sql(query).await?.into_optimized_plan()?; @@ -754,7 +755,7 @@ async fn create_json_scan(ctx: &SessionContext) -> Result) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, &ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -142,13 +150,19 @@ fn roundtrip_test_and_return( exec_plan: Arc, ctx: &SessionContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) - .expect("to proto"); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), codec) - .expect("from proto"); + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan), + codec, + proto_converter, + )?; + let result_exec_plan 
= physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + codec, + proto_converter, + )?; pretty_assertions::assert_eq!( format!("{exec_plan:?}"), @@ -168,7 +182,8 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -176,9 +191,10 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; - roundtrip_test_and_return(initial_plan, ctx, &codec)?; + roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -285,6 +301,7 @@ fn roundtrip_hash_join() -> Result<()> { None, *partition_mode, NullEquality::NullEqualsNothing, + false, )?))?; } } @@ -615,7 +632,7 @@ fn roundtrip_aggregate_with_limit() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?; - let agg = agg.with_limit(Some(12)); + let agg = agg.with_limit_options(Some(LimitOptions::new_with_order(12, false))); roundtrip_test(Arc::new(agg)) } @@ -912,6 +929,30 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { roundtrip_test(DataSourceExec::from_data_source(scan_config)) } +#[test] +fn roundtrip_arrow_scan() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + + let table_schema = TableSchema::new(file_schema.clone(), vec![]); + let file_source = Arc::new(ArrowSource::new_file_source(table_schema)); + + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + 
.with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.arrow".to_string(), + 1024, + )])]) + .with_statistics(Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&file_schema), + }) + .build(); + + roundtrip_test(DataSourceExec::from_data_source(scan_config)) +} + #[tokio::test] async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { let mut file_group = @@ -987,7 +1028,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } impl Display for CustomPredicateExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CustomPredicateExpr") } } @@ -1080,7 +1121,12 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { let exec_plan = DataSourceExec::from_data_source(scan_config); let ctx = SessionContext::new(); - roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + roundtrip_test_and_return( + exec_plan, + &ctx, + &CustomPhysicalExtensionCodec {}, + &DefaultPhysicalProtoConverter {}, + )?; Ok(()) } @@ -1286,7 +1332,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1333,7 +1380,8 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1404,7 +1452,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - 
roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1474,6 +1523,7 @@ fn roundtrip_json_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "json".into(), + file_output_mode: FileOutputMode::SingleFile, }; let data_sink = Arc::new(JsonSink::new( file_sink_config, @@ -1512,6 +1562,7 @@ fn roundtrip_csv_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "csv".into(), + file_output_mode: FileOutputMode::Directory, }; let data_sink = Arc::new(CsvSink::new( file_sink_config, @@ -1528,12 +1579,14 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtrip_plan = roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), &ctx, &codec, - ) - .unwrap(); + &proto_converter, + )?; let roundtrip_plan = roundtrip_plan .as_any() @@ -1569,6 +1622,7 @@ fn roundtrip_parquet_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let data_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1820,11 +1874,12 @@ async fn roundtrip_projection_source() -> Result<()> { .build(); let filter = Arc::new( - FilterExec::try_new( + FilterExecBuilder::new( Arc::new(BinaryExpr::new(col("c", &schema)?, Operator::Eq, lit(1))), DataSourceExec::from_data_source(scan_config), - )? - .with_projection(Some(vec![0, 1]))?, + ) + .apply_projection(Some(vec![0, 1]))? 
+ .build()?, ); roundtrip_test(filter) @@ -1974,6 +2029,7 @@ async fn test_serialize_deserialize_tpch_queries() -> Result<()> { // serialize the physical plan let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; @@ -2095,6 +2151,7 @@ async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { // Serialize the physical plan - bug may happen here already but not necessarily manifests let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // This will fail with the bug, but should succeed when fixed @@ -2336,9 +2393,8 @@ async fn roundtrip_async_func_exec() -> Result<()> { /// it's a performance optimization filter, not a correctness requirement. #[test] fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { - use datafusion::physical_plan::joins::HashTableLookupExpr; - use datafusion::physical_plan::joins::Map; use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32; + use datafusion::physical_plan::joins::{HashTableLookupExpr, Map}; // Create a simple schema and input plan let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); @@ -2360,8 +2416,9 @@ fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { // Serialize let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) + + let proto: PhysicalPlanNode = + PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) .expect("serialization should succeed"); // Deserialize @@ -2411,3 +2468,590 @@ fn roundtrip_hash_expr() -> Result<()> { ); roundtrip_test(filter) } + +#[test] +fn custom_proto_converter_intercepts() -> Result<()> { + #[derive(Default)] + struct CustomConverterInterceptor { + num_proto_plans: RwLock, + num_physical_plans: RwLock, + 
num_proto_exprs: RwLock, + num_physical_exprs: RwLock, + } + + impl PhysicalProtoConverterExtension for CustomConverterInterceptor { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto: &protobuf::PhysicalPlanNode, + ) -> Result> { + { + let mut counter = self + .num_proto_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + proto.try_into_physical_plan_with_converter(ctx, codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + { + let mut counter = self + .num_physical_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> + where + Self: Sized, + { + { + let mut counter = self + .num_proto_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + { + let mut counter = self + .num_physical_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + serialize_physical_expr_with_converter(expr, codec, self) + } + } + + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = [ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + 
nulls_first: true, + }, + }, + ] + .into(); + + let exec_plan = Arc::new(SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema)))); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = CustomConverterInterceptor::default(); + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + assert_eq!(*proto_converter.num_proto_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_proto_plans.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_plans.read().unwrap(), 2); + + Ok(()) +} + +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +} + +/// Test that expression deduplication works during deserialization. +/// When the same expression Arc is serialized multiple times, it should be +/// deduplicated on deserialization (sharing the same Arc). 
+#[test] +fn test_expression_deduplication() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a shared expression that will be used multiple times + let shared_col: Arc = Arc::new(Column::new("a", 0)); + + // Create an InList expression that uses the same column Arc multiple times + // This simulates a real-world scenario where expressions are shared + let in_list_expr = in_list( + Arc::clone(&shared_col), + vec![lit(1i64), lit(2i64), lit(3i64)], + &false, + &schema, + )?; + + // Create a binary expression that uses the shared column and the in_list result + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&shared_col), + Operator::Eq, + lit(42i64), + )); + + // Create a plan that has both expressions (they share the `shared_col` Arc) + let input = Arc::new(EmptyExec::new(schema.clone())); + let filter = FilterExecBuilder::new(in_list_expr, input).build()?; + let projection_exprs = vec![ProjectionExpr { + expr: binary_expr, + alias: "result".to_string(), + }]; + let exec_plan = + Arc::new(ProjectionExec::try_new(projection_exprs, Arc::new(filter))?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Perform roundtrip + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Create a new converter for deserialization (fresh cache) + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify the plan structure is correct + pretty_assertions::assert_eq!(format!("{exec_plan:?}"), format!("{result_plan:?}")); + + Ok(()) +} + +/// Test that expression deduplication correctly shares Arcs for identical expressions. 
+/// This test verifies the core deduplication behavior. +#[test] +fn test_expression_deduplication_arc_sharing() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a column expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + // Create a projection that uses the SAME Arc twice + // After roundtrip, both should point to the same Arc + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc! + alias: "a2".to_string(), + }, + ]; + + let input = Arc::new(EmptyExec::new(schema)); + let exec_plan = Arc::new(ProjectionExec::try_new(projection_exprs, input)?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Deserialize with a fresh converter + let deser_converter = DeduplicatingProtoConverter {}; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Get the projection from the result + let projection = result_plan + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + + let exprs: Vec<_> = projection.expr().iter().collect(); + assert_eq!(exprs.len(), 2); + + // The key test: both expressions should point to the same Arc after deduplication + // This is because they were the same Arc before serialization + assert!( + Arc::ptr_eq(&exprs[0].expr, &exprs[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + Ok(()) +} + +/// Test 
backward compatibility: protos without expr_id should still deserialize correctly. +#[test] +fn test_backward_compatibility_no_expr_id() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Manually create a proto without expr_id set + let proto = PhysicalExprNode { + expr_id: None, // Simulating old proto without this field + expr_type: Some( + datafusion_proto::protobuf::physical_expr_node::ExprType::Column( + datafusion_proto::protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + ), + ), + }; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + + // Should deserialize without error + let result = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Verify the result is correct + let col = result + .as_any() + .downcast_ref::() + .expect("Expected Column"); + assert_eq!(col.name(), "a"); + assert_eq!(col.index(), 0); + + Ok(()) +} + +/// Test that deduplication works within a single plan deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). 
+#[test] +fn test_deduplication_within_plan_deserialization() -> Result<()> { + use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, + }; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a plan with expressions that will be deduplicated + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ + ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&col_expr), // Same Arc - will be deduplicated + alias: "a2".to_string(), + }, + ]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs, + Arc::new(EmptyExec::new(schema)), + )?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // First deserialization + let plan1 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the plan was deserialized correctly with deduplication + let projection1 = plan1 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs1: Vec<_> = projection1.expr().iter().collect(); + assert_eq!(exprs1.len(), 2); + assert!( + Arc::ptr_eq(&exprs1[0].expr, &exprs1[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Second deserialization + let plan2 = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + // Check that the second plan was also deserialized correctly + let projection2 = plan2 + .as_any() + .downcast_ref::() + .expect("Expected ProjectionExec"); + let exprs2: Vec<_> 
= projection2.expr().iter().collect(); + assert_eq!(exprs2.len(), 2); + assert!( + Arc::ptr_eq(&exprs2[0].expr, &exprs2[1].expr), + "Expected both expressions to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(&exprs1[0].expr, &exprs2[0].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(&exprs1[1].expr, &exprs2[1].expr), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +/// Test that deduplication works within direct expression deserialization and that +/// separate deserializations produce independent expressions (no cross-operation sharing). +#[test] +fn test_deduplication_within_expr_deserialization() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a binary expression where both sides are the same Arc + // This allows us to test deduplication within a single deserialization + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let binary_expr: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&col_expr), + Operator::Plus, + Arc::clone(&col_expr), // Same Arc - will be deduplicated + )); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // Serialize the expression + let proto = proto_converter.physical_expr_to_proto(&binary_expr, &codec)?; + + // First expression deserialization + let expr1 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that deduplication worked within the deserialization + let binary1 = expr1 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary1.left(), binary1.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // 
Second expression deserialization + let expr2 = proto_converter.proto_to_physical_expr( + &proto, + ctx.task_ctx().as_ref(), + &schema, + &codec, + )?; + + // Check that the second expression was also deserialized correctly + let binary2 = expr2 + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + assert!( + Arc::ptr_eq(binary2.left(), binary2.right()), + "Expected both sides to share the same Arc after deduplication" + ); + + // Check that there was no deduplication across deserializations + assert!( + !Arc::ptr_eq(binary1.left(), binary2.left()), + "Expected expressions from different deserializations to be different Arcs" + ); + assert!( + !Arc::ptr_eq(binary1.right(), binary2.right()), + "Expected expressions from different deserializations to be different Arcs" + ); + + Ok(()) +} + +/// Test that session_id rotates between top-level serialization operations. +/// This verifies that each top-level serialization gets a fresh session_id, +/// which prevents cross-process collisions when serialized plans are merged. 
+#[test] +fn test_session_id_rotation_between_serializations() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let _schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple expression + let col_expr: Arc = Arc::new(Column::new("a", 0)); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let proto1 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id1 = proto1.expr_id.expect("Expected expr_id to be set"); + + // Second serialization with the same converter + // The session_id should have rotated, so the expr_id should be different + // even though we're serializing the same expression (same pointer address) + let proto2 = proto_converter.physical_expr_to_proto(&col_expr, &codec)?; + let expr_id2 = proto2.expr_id.expect("Expected expr_id to be set"); + + // The expr_ids should be different because session_id rotated + assert_ne!( + expr_id1, expr_id2, + "Expected different expr_ids due to session_id rotation between serializations" + ); + + // Also test that serializing the same expression multiple times within + // the same top-level operation would give the same expr_id (not testable + // here directly since each physical_expr_to_proto is a top-level operation, + // but the deduplication tests verify this indirectly) + + Ok(()) +} + +/// Test that session_id rotation works correctly with execution plans. +/// This verifies the end-to-end behavior with plan serialization. 
+#[test] +fn test_session_id_rotation_with_execution_plans() -> Result<()> { + use datafusion_proto::bytes::physical_plan_to_bytes_with_proto_converter; + + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Create a simple plan + let col_expr: Arc = Arc::new(Column::new("a", 0)); + let projection_exprs = vec![ProjectionExpr { + expr: Arc::clone(&col_expr), + alias: "a1".to_string(), + }]; + let exec_plan = Arc::new(ProjectionExec::try_new( + projection_exprs.clone(), + Arc::new(EmptyExec::new(Arc::clone(&schema))), + )?); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + // First serialization + let bytes1 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // Second serialization with the same converter + let bytes2 = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan) as Arc, + &codec, + &proto_converter, + )?; + + // The serialized bytes should be different due to different session_ids + // (specifically, the expr_id values embedded in the protobuf will differ) + assert_ne!( + bytes1.as_ref(), + bytes2.as_ref(), + "Expected different serialized bytes due to session_id rotation" + ); + + // But both should deserialize correctly + let ctx = SessionContext::new(); + let deser_converter = DeduplicatingProtoConverter {}; + + let plan1 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes1.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + let plan2 = datafusion_proto::bytes::physical_plan_from_bytes_with_proto_converter( + bytes2.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Verify both plans have the expected structure + assert_eq!(plan1.schema(), plan2.schema()); + + Ok(()) +} diff --git a/datafusion/pruning/LICENSE.txt b/datafusion/pruning/LICENSE.txt new file mode 
120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/pruning/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/pruning/NOTICE.txt b/datafusion/pruning/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/pruning/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs index 9f8142447ba69..be17f29eaafa0 100644 --- a/datafusion/pruning/src/lib.rs +++ b/datafusion/pruning/src/lib.rs @@ -16,7 +16,6 @@ // under the License. #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] mod file_pruner; mod pruning_predicate; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index b5b8267d7f93f..d0cb0674424ba 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -492,7 +492,6 @@ impl PruningPredicate { // Simplify the newly created predicate to get rid of redundant casts, comparisons, etc. let predicate_expr = PhysicalExprSimplifier::new(&predicate_schema).simplify(predicate_expr)?; - let literal_guarantees = LiteralGuarantee::analyze(&expr); Ok(Self { @@ -1206,13 +1205,6 @@ fn is_compare_op(op: Operator) -> bool { ) } -fn is_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. @@ -1234,7 +1226,7 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re // If both types are strings or both are not strings (number, timestamp, etc) // then we can compare them. // PruningPredicate does not support casting of strings to numbers and such. 
- if is_string_type(from_type) == is_string_type(to_type) { + if from_type.is_string() == to_type.is_string() { Ok(()) } else { plan_err!( @@ -4682,7 +4674,7 @@ mod tests { true, // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) true, - // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate // original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") true, ]; diff --git a/datafusion/session/src/lib.rs b/datafusion/session/src/lib.rs index 3d3cb541b5a5e..11f734e757452 100644 --- a/datafusion/session/src/lib.rs +++ b/datafusion/session/src/lib.rs @@ -16,7 +16,6 @@ // under the License. #![cfg_attr(test, allow(clippy::needless_pass_by_value))] -#![deny(clippy::allow_attributes)] //! Session management for DataFusion query execution environment //! diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 0dc35f4a87776..8a5c68a5d4e4b 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -29,6 +29,10 @@ edition = { workspace = true } [package.metadata.docs.rs] all-features = true +[features] +default = [] +core = ["datafusion"] + # Note: add additional linter rules in lib.rs. 
# Rust does not support workspace + new linter rules in subcrates yet # https://github.com/rust-lang/cargo/issues/13157 @@ -43,6 +47,8 @@ arrow = { workspace = true } bigdecimal = { workspace = true } chrono = { workspace = true } crc32fast = "1.4" +# Optional dependency for SessionStateBuilderSpark extension trait +datafusion = { workspace = true, optional = true, default-features = false } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } @@ -54,10 +60,14 @@ log = { workspace = true } percent-encoding = "2.3.2" rand = { workspace = true } sha1 = "0.10" +sha2 = { workspace = true } url = { workspace = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +# for SessionStateBuilderSpark tests +datafusion = { workspace = true, default-features = false } [[bench]] harness = false @@ -66,3 +76,23 @@ name = "char" [[bench]] harness = false name = "space" + +[[bench]] +harness = false +name = "hex" + +[[bench]] +harness = false +name = "slice" + +[[bench]] +harness = false +name = "substring" + +[[bench]] +harness = false +name = "unhex" + +[[bench]] +harness = false +name = "sha2" diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index b5f87857ae9c6..38d9ebdeb4f5f 100644 --- a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{array::PrimitiveArray, datatypes::Int64Type}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/spark/benches/hex.rs b/datafusion/spark/benches/hex.rs new file mode 100644 index 0000000000000..9785371cc5827 --- /dev/null +++ b/datafusion/spark/benches/hex.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::hex::SparkHex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_int64_data(size: usize, null_density: f32) -> PrimitiveArray { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-999_999_999_999..999_999_999_999)) + } + }) + .collect() +} + +fn generate_utf8_data(size: usize, null_density: f32) -> StringArray { + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let s: String = + std::iter::repeat_with(|| rng.random_range(b'a'..=b'z') as char) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn generate_int64_dict_data( + size: usize, + null_density: f32, +) -> DictionaryArray { + let mut rng = seedable_rng(); + let mut builder = PrimitiveDictionaryBuilder::::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value( + rng.random_range::(-999_999_999_999..999_999_999_999), + ); + } + } + builder.finish() +} + +fn 
run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let hex_func = SparkHex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + hex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let data = generate_int64_data(size, null_density); + run_benchmark(c, "hex_int64", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_utf8_data(size, null_density); + run_benchmark(c, "hex_utf8", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_binary_data(size, null_density); + run_benchmark(c, "hex_binary", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_int64_dict_data(size, null_density); + run_benchmark(c, "hex_int64_dict", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/sha2.rs b/datafusion/spark/benches/sha2.rs new file mode 100644 index 0000000000000..6e835984703f0 --- /dev/null +++ b/datafusion/spark/benches/sha2.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::hash::sha2::SparkSha2; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, args: &[ColumnarValue]) { + let sha2_func = SparkSha2::new(); + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + sha2_func + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + 
arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + // Scalar benchmark (avoid array expansion) + let scalar_args = vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(b"Spark".to_vec()))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/scalar", 1, &scalar_args); + + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let values: ArrayRef = Arc::new(generate_binary_data(size, null_density)); + let bit_lengths: ArrayRef = Arc::new(Int32Array::from(vec![256; size])); + + let array_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Array(Arc::clone(&bit_lengths)), + ]; + run_benchmark(c, "sha2/array_binary_256", size, &array_args); + + let array_scalar_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/array_scalar_binary_256", size, &array_scalar_args); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/slice.rs b/datafusion/spark/benches/slice.rs new file mode 100644 index 0000000000000..da392dc042f92 --- /dev/null +++ b/datafusion/spark/benches/slice.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Int64Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::array::slice; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn create_inputs( + rng: &mut StdRng, + size: usize, + child_array_size: usize, + null_density: f32, +) -> (ListArray, ListViewArray) { + let mut nulls_builder = NullBufferBuilder::new(size); + let mut sizes = Vec::with_capacity(size); + + for _ in 0..size { + if rng.random::() < null_density { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + sizes.push(rng.random_range(1..child_array_size)); + } + let nulls = nulls_builder.finish(); + + let length = sizes.iter().sum(); + let values: PrimitiveArray = + (0..length).map(|_| Some(rng.random())).collect(); + let values = Arc::new(values); + + let offsets = OffsetBuffer::from_lengths(sizes.clone()); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets.clone(), + values.clone(), + nulls.clone(), + ); + + let offsets = ScalarBuffer::from(offsets.slice(0, size - 1)); + let sizes = ScalarBuffer::from_iter(sizes.into_iter().map(|v| v as i32)); + let list_view_array = 
ListViewArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets, + sizes, + values, + nulls, + ); + + (list_array, list_view_array) +} + +fn random_from_to( + rng: &mut StdRng, + size: i64, + null_density: f32, +) -> (Option, Option) { + let from = if rng.random::() < null_density { + None + } else { + Some(rng.random_range(1..=size)) + }; + + let to = if rng.random::() < null_density { + None + } else { + match from { + Some(from) => Some(rng.random_range(from..=size)), + None => Some(rng.random_range(1..=size)), + } + }; + + (from, to) +} + +fn array_slice_benchmark( + name: &str, + input: ColumnarValue, + mut args: Vec, + c: &mut Criterion, + size: usize, +) { + args.insert(0, input); + + let array_slice = slice(); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + >::from(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + array_slice + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new_list_field(args[0].data_type(), true) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rng = &mut StdRng::seed_from_u64(42); + let size = 1_000_000; + let child_array_size = 100; + let null_density = 0.1; + + let (list_array, list_view_array) = + create_inputs(rng, size, child_array_size, null_density); + + let mut array_from = Vec::with_capacity(size); + let mut array_to = Vec::with_capacity(size); + for child_array_size in list_array.offsets().lengths() { + let (from, to) = random_from_to(rng, child_array_size as i64, null_density); + array_from.push(from); + array_to.push(to); + } + + // input + let list_array = ColumnarValue::Array(Arc::new(list_array)); + let list_view_array = ColumnarValue::Array(Arc::new(list_view_array)); + + // args + let 
array_from = ColumnarValue::Array(Arc::new(Int64Array::from(array_from))); + let array_to = ColumnarValue::Array(Arc::new(Int64Array::from(array_to))); + let scalar_from = ColumnarValue::Scalar(ScalarValue::from(1i64)); + let scalar_to = ColumnarValue::Scalar(ScalarValue::from(child_array_size as i64 / 2)); + + for input in [list_array, list_view_array] { + let input_type = input.data_type().to_string(); + + array_slice_benchmark( + &format!("slice: input {input_type}, array args, no stride"), + input.clone(), + vec![array_from.clone(), array_to.clone()], + c, + size, + ); + + array_slice_benchmark( + &format!("slice: input {input_type}, scalar args, no stride"), + input.clone(), + vec![scalar_from.clone(), scalar_to.clone()], + c, + size, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/space.rs b/datafusion/spark/benches/space.rs index 8ace7219a1dcc..bd9d370ca37fe 100644 --- a/datafusion/spark/benches/space.rs +++ b/datafusion/spark/benches/space.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::PrimitiveArray; use arrow::datatypes::{DataType, Field, Int32Type}; use criterion::{Criterion, criterion_group, criterion_main}; diff --git a/datafusion/spark/benches/substring.rs b/datafusion/spark/benches/substring.rs new file mode 100644 index 0000000000000..d6eac817c322f --- /dev/null +++ b/datafusion/spark/benches/substring.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::DataFusionError; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::string::substring; +use std::hint::black_box; +use std::sync::Arc; + +fn create_args_without_count( + size: usize, + str_len: usize, + start_half_way: bool, + force_view_types: bool, +) -> Vec { + let start_array = Arc::new(Int64Array::from( + (0..size) + .map(|_| { + if start_half_way { + (str_len / 2) as i64 + } else { + 1i64 + } + }) + .collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ] + } +} + +fn create_args_with_count( + size: usize, + str_len: usize, + count_max: usize, + force_view_types: bool, +) -> Vec { + let start_array = + Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); + let count = count_max.min(str_len) as i64; + let count_array = Arc::new(Int64Array::from( + (0..size).map(|_| 
count).collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ColumnarValue::Array(count_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), + ] + } +} + +#[expect(clippy::needless_pass_by_value)] +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + substring().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + // string_len = 12, substring_len=6 (see `create_args_without_count`) + let len = 12; + let mut group = c.benchmark_group("SHORTER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + let args = create_args_without_count::(size, len, true, false); + group.bench_function( + format!("substr_large_string [size={size}, 
strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=64, substring_len=64 + let len = 128; + let count = 64; + let mut group = c.benchmark_group("LONGER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=6, substring_len=6 + let len = 128; + let count = 6; + let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| 
black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/unhex.rs b/datafusion/spark/benches/unhex.rs new file mode 100644 index 0000000000000..7dce683485bc7 --- /dev/null +++ b/datafusion/spark/benches/unhex.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + Array, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, + StringViewArray, StringViewBuilder, +}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::unhex::SparkUnhex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn generate_hex_string_data(size: usize, null_density: f32) -> StringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_large_string_data(size: usize, null_density: f32) -> LargeStringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = LargeStringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_utf8view_data(size: usize, null_density: f32) -> StringViewArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringViewBuilder::with_capacity(size); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len 
= rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let unhex_func = SparkUnhex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + unhex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Binary, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + // Benchmark with hex string + for &size in &sizes { + let data = generate_hex_string_data(size, null_density); + run_benchmark(c, "unhex_utf8", size, Arc::new(data)); + } + + // Benchmark with hex large string + for &size in &sizes { + let data = generate_hex_large_string_data(size, null_density); + run_benchmark(c, "unhex_large_utf8", size, Arc::new(data)); + } + + // Benchmark with hex Utf8View + for &size in &sizes { + let data = generate_hex_utf8view_data(size, null_density); + run_benchmark(c, "unhex_utf8view", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index 7140653510e09..0d4cd40d99329 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -17,6 +17,7 @@ pub mod repeat; pub mod 
shuffle; +pub mod slice; pub mod spark_array; use datafusion_expr::ScalarUDF; @@ -26,6 +27,7 @@ use std::sync::Arc; make_udf_function!(spark_array::SparkArray, array); make_udf_function!(shuffle::SparkShuffle, shuffle); make_udf_function!(repeat::SparkArrayRepeat, array_repeat); +make_udf_function!(slice::SparkSlice, slice); pub mod expr_fn { use datafusion_functions::export_functions; @@ -41,8 +43,13 @@ pub mod expr_fn { "returns an array containing element count times.", element count )); + export_functions!(( + slice, + "Returns a slice of the array from the start index with the given length.", + array start length + )); } pub fn functions() -> Vec> { - vec![array(), shuffle(), array_repeat()] + vec![array(), shuffle(), array_repeat(), slice()] } diff --git a/datafusion/spark/src/function/array/slice.rs b/datafusion/spark/src/function/array/slice.rs new file mode 100644 index 0000000000000..6c168a4f491b5 --- /dev/null +++ b/datafusion/spark/src/function/array/slice.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Array, ArrayRef, Int64Builder}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_int64_array, as_list_array}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions_nested::extract::array_slice_udf; +use std::any::Any; +use std::sync::Arc; + +/// Spark slice function implementation +/// Main difference from DataFusion's array_slice is that the third argument is the length of the slice and not the end index. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSlice { + signature: Signature, +} + +impl Default for SparkSlice { + fn default() -> Self { + Self::new() + } +} + +impl SparkSlice { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkSlice { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "slice" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + "slice", + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args( + &self, + mut func_args: ScalarFunctionArgs, + ) -> 
Result { + let array_len = func_args + .args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(func_args.number_rows); + + let arrays = func_args + .args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect::>>()?; + + let (start, end) = calculate_start_end(&arrays)?; + + array_slice_udf().invoke_with_args(ScalarFunctionArgs { + args: vec![ + func_args.args.swap_remove(0), + ColumnarValue::Array(start), + ColumnarValue::Array(end), + ], + arg_fields: func_args.arg_fields, + number_rows: func_args.number_rows, + return_field: func_args.return_field, + config_options: func_args.config_options, + }) + } +} + +fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> { + let [values, start, length] = take_function_args("slice", args)?; + + let values_len = values.len(); + + let start = as_int64_array(&start)?; + let length = as_int64_array(&length)?; + + let values = as_list_array(values)?; + + let mut adjusted_start = Int64Builder::with_capacity(values_len); + let mut end = Int64Builder::with_capacity(values_len); + + for row in 0..values_len { + if values.is_null(row) || start.is_null(row) || length.is_null(row) { + adjusted_start.append_null(); + end.append_null(); + continue; + } + let start = start.value(row); + let length = length.value(row); + let value_length = values.value(row).len() as i64; + + if start == 0 { + return exec_err!("Start index must not be zero"); + } + if length < 0 { + return exec_err!("Length must be non-negative, but got {}", length); + } + + let adjusted_start_value = if start < 0 { + start + value_length + 1 + } else { + start + }; + + adjusted_start.append_value(adjusted_start_value); + end.append_value(adjusted_start_value + (length - 1)); + } + + Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish()))) +} diff --git 
a/datafusion/spark/src/function/array/spark_array.rs b/datafusion/spark/src/function/array/spark_array.rs index 6d9f9a1695e1b..1ad0a394b8ca6 100644 --- a/datafusion/spark/src/function/array/spark_array.rs +++ b/datafusion/spark/src/function/array/spark_array.rs @@ -23,7 +23,7 @@ use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{Result, internal_err}; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + Volatility, }; use datafusion_functions_nested::make_array::{array_array, coerce_types_inner}; @@ -45,10 +45,7 @@ impl Default for SparkArray { impl SparkArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -104,12 +101,12 @@ impl ScalarUDFImpl for SparkArray { make_scalar_function(make_array_inner)(args.as_slice()) } - fn aliases(&self) -> &[String] { - &[] - } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - coerce_types_inner(arg_types, self.name()) + if arg_types.is_empty() { + Ok(vec![]) + } else { + coerce_types_inner(arg_types, self.name()) + } } } diff --git a/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs new file mode 100644 index 0000000000000..3871d00cc91d8 --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; +use datafusion::logical_expr::{ColumnarValue, Signature, TypeSignature, Volatility}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `bitmap_bit_position` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapBitPosition { + signature: Signature, +} + +impl Default for BitmapBitPosition { + fn default() -> Self { + Self::new() + } +} + +impl BitmapBitPosition { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapBitPosition { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bitmap_bit_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), 
+ DataType::Int64, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_bit_position_inner, vec![])(&args.args) + } +} + +pub fn bitmap_bit_position_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bitmap_bit_position", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(bitmap_bit_position)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bitmap_bit_position does not support {data_type}") + } + } +} + +const NUM_BYTES: i64 = 4 * 1024; +const NUM_BITS: i64 = NUM_BYTES * 8; + +fn bitmap_bit_position(value: i64) -> i64 { + if value > 0 { + (value - 1) % NUM_BITS + } else { + (value.wrapping_neg()) % NUM_BITS + } +} diff --git a/datafusion/spark/src/function/bitmap/mod.rs b/datafusion/spark/src/function/bitmap/mod.rs index 8532c32ac9c5f..1a7dce02db3a3 100644 --- a/datafusion/spark/src/function/bitmap/mod.rs +++ b/datafusion/spark/src/function/bitmap/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+pub mod bitmap_bit_position; pub mod bitmap_count; use datafusion_expr::ScalarUDF; @@ -22,6 +23,7 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(bitmap_count::BitmapCount, bitmap_count); +make_udf_function!(bitmap_bit_position::BitmapBitPosition, bitmap_bit_position); pub mod expr_fn { use datafusion_functions::export_functions; @@ -31,8 +33,13 @@ pub mod expr_fn { "Returns the number of set bits in the input bitmap.", arg )); + export_functions!(( + bitmap_bit_position, + "Returns the bit position for the given input child expression.", + arg + )); } pub fn functions() -> Vec> { - vec![bitmap_count()] + vec![bitmap_count(), bitmap_bit_position()] } diff --git a/datafusion/spark/src/function/bitwise/bitwise_not.rs b/datafusion/spark/src/function/bitwise/bitwise_not.rs index 5f8cf36911f43..e7285d4804950 100644 --- a/datafusion/spark/src/function/bitwise/bitwise_not.rs +++ b/datafusion/spark/src/function/bitwise/bitwise_not.rs @@ -73,25 +73,11 @@ impl ScalarUDFImpl for SparkBitwiseNot { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - if args.arg_fields.len() != 1 { - return plan_err!("bitwise_not expects exactly 1 argument"); - } - - let input_field = &args.arg_fields[0]; - - let out_dt = input_field.data_type().clone(); - let mut out_nullable = input_field.is_nullable(); - - let scalar_null_present = args - .scalar_arguments - .iter() - .any(|opt_s| opt_s.is_some_and(|sv| sv.is_null())); - - if scalar_null_present { - out_nullable = true; - } - - Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable))) + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + args.arg_fields[0].is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -196,32 +182,4 @@ mod tests { assert!(out_i64_null.is_nullable()); assert_eq!(out_i64_null.data_type(), &DataType::Int64); } - - #[test] - fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> { - 
use arrow::datatypes::{DataType, Field}; - use datafusion_common::ScalarValue; - use std::sync::Arc; - - let func = SparkBitwiseNot::new(); - - let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false)); - - let out = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[None], - })?; - assert!(!out.is_nullable()); - assert_eq!(out.data_type(), &DataType::Int32); - - let null_scalar = ScalarValue::Int32(None); - let out_with_null_scalar = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[Some(&null_scalar)], - })?; - assert!(out_with_null_scalar.is_nullable()); - assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32); - - Ok(()) - } } diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs index a87df9a2c87a0..6871e3aba6469 100644 --- a/datafusion/spark/src/function/collection/mod.rs +++ b/datafusion/spark/src/function/collection/mod.rs @@ -15,11 +15,20 @@ // specific language governing permissions and limitations // under the License. +pub mod size; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(size::SparkSize, size); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((size, "Return the size of an array or map.", arg)); +} pub fn functions() -> Vec> { - vec![] + vec![size()] } diff --git a/datafusion/spark/src/function/collection/size.rs b/datafusion/spark/src/function/collection/size.rs new file mode 100644 index 0000000000000..05b8ba315675c --- /dev/null +++ b/datafusion/spark/src/function/collection/size.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Int32Array}; +use arrow::compute::kernels::length::length as arrow_length; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, plan_err}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `size` function. +/// +/// Returns the number of elements in an array or the number of key-value pairs in a map. +/// Returns -1 for null input (Spark behavior). 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSize { + signature: Signature, +} + +impl Default for SparkSize { + fn default() -> Self { + Self::new() + } +} + +impl SparkSize { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // Array Type + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + // Map Type + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkSize { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "size" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + // nullable=false for legacy behavior (NULL -> -1); set to input nullability for null-on-null + Ok(Arc::new(Field::new(self.name(), DataType::Int32, false))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_size_inner, vec![])(&args.args) + } +} + +fn spark_size_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + + match array.data_type() { + DataType::List(_) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let list_array = array.as_list::(); + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::FixedSizeList(_, size) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) 
+ } else { + let length: Vec = (0..array.len()) + .map(|i| if array.is_null(i) { -1 } else { *size }) + .collect(); + Ok(Arc::new(Int32Array::from(length))) + } + } + DataType::LargeList(_) => { + // Arrow length kernel returns Int64 for LargeList + let list_array = array.as_list::(); + if array.null_count() == 0 { + let lengths: Vec = list_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } else { + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::Map(_, _) => { + let map_array = array.as_map(); + let length: Vec = if array.null_count() == 0 { + map_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect() + } else { + map_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect() + }; + Ok(Arc::new(Int32Array::from(length))) + } + DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))), + dt => { + plan_err!("size function does not support type: {}", dt) + } + } +} diff --git a/datafusion/spark/src/function/datetime/add_months.rs b/datafusion/spark/src/function/datetime/add_months.rs new file mode 100644 index 0000000000000..fa9f6fa8db945 --- /dev/null +++ b/datafusion/spark/src/function/datetime/add_months.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::ops::Add; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAddMonths { + signature: Signature, +} + +impl Default for SparkAddMonths { + fn default() -> Self { + Self::new() + } +} + +impl SparkAddMonths { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Int32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkAddMonths { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "add_months" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date_arg, months_arg] = take_function_args("add_months", args)?; + let interval = months_arg + 
.cast_to(&DataType::Interval(IntervalUnit::YearMonth), info.schema())?; + Ok(ExprSimplifyResult::Simplified(date_arg.add(interval))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke should not be called on a simplified add_months() function") + } +} diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs index 78b9c904cee37..3745f77969f22 100644 --- a/datafusion/spark/src/function/datetime/date_add.rs +++ b/datafusion/spark/src/function/datetime/date_add.rs @@ -82,12 +82,7 @@ impl ScalarUDFImpl for SparkDateAdd { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -142,7 +137,6 @@ fn spark_date_add(args: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::datatypes::Field; - use datafusion_common::ScalarValue; #[test] fn test_date_add_non_nullable_inputs() { @@ -181,25 +175,4 @@ mod tests { assert_eq!(ret_field.data_type(), &DataType::Date32); assert!(ret_field.is_nullable()); } - - #[test] - fn test_date_add_null_scalar() { - let func = SparkDateAdd::new(); - let args = &[ - Arc::new(Field::new("date", DataType::Date32, false)), - Arc::new(Field::new("num", DataType::Int32, false)), - ]; - - let null_scalar = ScalarValue::Int32(None); - - let ret_field = func - .return_field_from_args(ReturnFieldArgs { - arg_fields: args, - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert_eq!(ret_field.data_type(), &DataType::Date32); - assert!(ret_field.is_nullable()); - } } diff --git a/datafusion/spark/src/function/datetime/date_diff.rs b/datafusion/spark/src/function/datetime/date_diff.rs new file mode 100644 index 
0000000000000..094c35eec56b5 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_diff.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, Operator, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, + binary_expr, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateDiff { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDateDiff { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateDiff { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + Coercion::new_implicit( + 
TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + ], + Volatility::Immutable, + ), + aliases: vec!["datediff".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkDateDiff { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_diff" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "Apache Spark `date_diff` should have been simplified to standard subtraction" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [end, start] = take_function_args(self.name(), args)?; + let end = end.cast_to(&DataType::Date32, info.schema())?; + let start = start.cast_to(&DataType::Date32, info.schema())?; + Ok(ExprSimplifyResult::Simplified( + binary_expr(end, Operator::Minus, start) + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/date_part.rs b/datafusion/spark/src/function/datetime/date_part.rs new file mode 100644 index 0000000000000..e30a162ef42db --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_part.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_date; +use datafusion_common::{ + Result, ScalarValue, internal_err, types::logical_string, utils::take_function_args, +}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use std::{any::Any, sync::Arc}; + +/// Wrapper around datafusion date_part function to handle +/// Spark behavior returning day of the week 1-indexed instead of 0-indexed and different part aliases. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDatePart { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDatePart { + fn default() -> Self { + Self::new() + } +} + +impl SparkDatePart { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Timestamp), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_date())), + ]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("datepart")], + } + } +} + +impl ScalarUDFImpl for SparkDatePart { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_part" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("Use return_field_from_args in this case instead.") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark date_part should have been simplified to standard date_part") + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [part_expr, date_expr] = take_function_args(self.name(), args)?; + + let part = match part_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return internal_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific date part aliases to datafusion ones + let part 
= match part.as_str() { + "yearofweek" | "year_iso" => "isoyear", + "dayofweek" => "dow", + "dayofweek_iso" | "dow_iso" => "isodow", + other => other, + }; + + let part_expr = Expr::Literal(ScalarValue::new_utf8(part), None); + + let date_part_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_part(), + vec![part_expr, date_expr], + )); + + match part { + // Add 1 for day-of-week parts to convert 0-indexed to 1-indexed + "dow" | "isodow" => Ok(ExprSimplifyResult::Simplified( + date_part_expr + Expr::Literal(ScalarValue::Int32(Some(1)), None), + )), + _ => Ok(ExprSimplifyResult::Simplified(date_part_expr)), + } + } +} diff --git a/datafusion/spark/src/function/datetime/date_sub.rs b/datafusion/spark/src/function/datetime/date_sub.rs index 34894317f67d3..af1b8d5a4e91e 100644 --- a/datafusion/spark/src/function/datetime/date_sub.rs +++ b/datafusion/spark/src/function/datetime/date_sub.rs @@ -75,12 +75,7 @@ impl ScalarUDFImpl for SparkDateSub { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -139,7 +134,6 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_common::ScalarValue; #[test] fn test_date_sub_nullability_non_nullable_args() { @@ -174,22 +168,4 @@ mod tests { assert!(result.is_nullable()); assert_eq!(result.data_type(), &DataType::Date32); } - - #[test] - fn test_date_sub_nullability_scalar_null_argument() { - let udf = SparkDateSub::new(); - let date_field = Arc::new(Field::new("d", DataType::Date32, false)); - let days_field = Arc::new(Field::new("n", DataType::Int32, false)); - let null_scalar = ScalarValue::Int32(None); - - let result = udf - 
.return_field_from_args(ReturnFieldArgs { - arg_fields: &[date_field, days_field], - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert!(result.is_nullable()); - assert_eq!(result.data_type(), &DataType::Date32); - } } diff --git a/datafusion/spark/src/function/datetime/date_trunc.rs b/datafusion/spark/src/function/datetime/date_trunc.rs new file mode 100644 index 0000000000000..2199c90703b38 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_trunc.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark date_trunc supports extra format aliases. 
+/// It also handles timestamps with timezones by converting to session timezone first. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateTrunc { + signature: Signature, +} + +impl Default for SparkDateTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkDateTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark date_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `DATE_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "dd" => 
"day", + other => other, + }; + + let session_tz = info.config_options().execution.time_zone.clone(); + let ts_type = ts_expr.get_type(info.schema())?; + + // Spark interprets timestamps in the session timezone before truncating, + // then returns a timestamp at microsecond precision. + // See: https://github.com/apache/spark/blob/f310f4fcc95580a6824bc7d22b76006f79b8804a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L492 + // + // For sub-second truncations (second, millisecond, microsecond), timezone + // adjustment is unnecessary since timezone offsets are whole seconds. + let ts_expr = match (&ts_type, fmt) { + // Sub-second truncations don't need timezone adjustment + (_, "second" | "millisecond" | "microsecond") => ts_expr, + + // convert to session timezone, strip timezone and convert back to original timezone + (DataType::Timestamp(unit, tz), _) => { + let ts_expr = match &session_tz { + Some(session_tz) => ts_expr.cast_to( + &DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from(session_tz.as_str())), + ), + info.schema(), + )?, + None => ts_expr, + }; + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::to_local_time(), + vec![ts_expr], + )) + .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())? 
+ } + + _ => { + return plan_err!( + "Second argument of `DATE_TRUNC` must be Timestamp, got {}", + ts_type + ); + } + }; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![fmt_expr, ts_expr], + ), + ))) + } +} diff --git a/datafusion/spark/src/function/datetime/from_utc_timestamp.rs b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs new file mode 100644 index 0000000000000..77e033af98a5c --- /dev/null +++ b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_datafusion_err, exec_err, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::datetime::common::adjust_to_local_time; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `from_utc_timestamp` function. +/// +/// Interprets the given timestamp as UTC and converts it to the given timezone. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from UTC timezone to +/// the given timezone. 
+/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFromUtcTimestamp { + signature: Signature, +} + +impl Default for SparkFromUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkFromUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkFromUtcTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "from_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_from_utc_timestamp, vec![])(&args.args) + } +} + +fn spark_from_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("from_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => 
{ + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`from_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`from_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`from_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_local_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs index 849aa20895990..3133ed7337f25 100644 --- a/datafusion/spark/src/function/datetime/mod.rs +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -15,20 +15,37 @@ // specific language governing permissions and limitations // under the License. 
+pub mod add_months; pub mod date_add; +pub mod date_diff; +pub mod date_part; pub mod date_sub; +pub mod date_trunc; pub mod extract; +pub mod from_utc_timestamp; pub mod last_day; pub mod make_dt_interval; pub mod make_interval; pub mod next_day; +pub mod time_trunc; +pub mod to_utc_timestamp; +pub mod trunc; +pub mod unix; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(add_months::SparkAddMonths, add_months); make_udf_function!(date_add::SparkDateAdd, date_add); +make_udf_function!(date_diff::SparkDateDiff, date_diff); +make_udf_function!(date_part::SparkDatePart, date_part); make_udf_function!(date_sub::SparkDateSub, date_sub); +make_udf_function!(date_trunc::SparkDateTrunc, date_trunc); +make_udf_function!( + from_utc_timestamp::SparkFromUtcTimestamp, + from_utc_timestamp +); make_udf_function!(extract::SparkHour, hour); make_udf_function!(extract::SparkMinute, minute); make_udf_function!(extract::SparkSecond, second); @@ -36,10 +53,34 @@ make_udf_function!(last_day::SparkLastDay, last_day); make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval); make_udf_function!(make_interval::SparkMakeInterval, make_interval); make_udf_function!(next_day::SparkNextDay, next_day); +make_udf_function!(time_trunc::SparkTimeTrunc, time_trunc); +make_udf_function!(to_utc_timestamp::SparkToUtcTimestamp, to_utc_timestamp); +make_udf_function!(trunc::SparkTrunc, trunc); +make_udf_function!(unix::SparkUnixDate, unix_date); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_micros, + unix::SparkUnixTimestamp::microseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_millis, + unix::SparkUnixTimestamp::milliseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_seconds, + unix::SparkUnixTimestamp::seconds +); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + add_months, + "Returns the date that is months months after start. 
The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); export_functions!(( date_add, "Returns the date that is days days after start. The function returns NULL if at least one of the input parameters is NULL.", @@ -83,18 +124,85 @@ pub mod expr_fn { "Returns the first date which is later than start_date and named as indicated. The function returns NULL if at least one of the input parameters is NULL.", arg1 arg2 )); + export_functions!(( + date_diff, + "Returns the number of days from start `start` to end `end`.", + end start + )); + export_functions!(( + date_trunc, + "Truncates a timestamp `ts` to the unit specified by the format `fmt`.", + fmt ts + )); + export_functions!(( + time_trunc, + "Truncates a time `t` to the unit specified by the format `fmt`.", + fmt t + )); + export_functions!(( + trunc, + "Truncates a date `dt` to the unit specified by the format `fmt`.", + dt fmt + )); + export_functions!(( + date_part, + "Extracts a part of the date or time from a date, time, or timestamp expression.", + arg1 arg2 + )); + export_functions!(( + from_utc_timestamp, + "Interpret a given timestamp `ts` in UTC timezone and then convert it to timezone `tz`.", + ts tz + )); + export_functions!(( + to_utc_timestamp, + "Interpret a given timestamp `ts` in timezone `tz` and then convert it to UTC timezone.", + ts tz + )); + export_functions!(( + unix_date, + "Returns the number of days since epoch (1970-01-01) for the given date `dt`.", + dt + )); + export_functions!(( + unix_micros, + "Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_millis, + "Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_seconds, + "Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); } pub fn functions() -> Vec> { 
vec![ + add_months(), date_add(), + date_diff(), + date_part(), date_sub(), + date_trunc(), + from_utc_timestamp(), hour(), - minute(), - second(), last_day(), make_dt_interval(), make_interval(), + minute(), next_day(), + second(), + time_trunc(), + to_utc_timestamp(), + trunc(), + unix_date(), + unix_micros(), + unix_millis(), + unix_seconds(), ] } diff --git a/datafusion/spark/src/function/datetime/time_trunc.rs b/datafusion/spark/src/function/datetime/time_trunc.rs new file mode 100644 index 0000000000000..718502a05ee6d --- /dev/null +++ b/datafusion/spark/src/function/datetime/time_trunc.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_string; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; + +/// Spark time_trunc function only handles time inputs. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTimeTrunc { + signature: Signature, +} + +impl Default for SparkTimeTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTimeTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTimeTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "time_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark time_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let fmt_expr = &args[0]; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `TIME_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + if !matches!( + fmt.as_str(), + "hour" | "minute" | "second" | "millisecond" | "microsecond" + ) { + return plan_err!( + "The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond" + ); + } + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf(datafusion_functions::datetime::date_trunc(), args), + ))) + } +} diff --git 
a/datafusion/spark/src/function/datetime/to_utc_timestamp.rs b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs new file mode 100644 index 0000000000000..0e8c267a390e1 --- /dev/null +++ b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use chrono::{DateTime, Offset, TimeZone}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + Result, exec_datafusion_err, exec_err, internal_datafusion_err, internal_err, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `to_utc_timestamp` function. 
+/// +/// Interprets the given timestamp in the provided timezone and then converts it to UTC. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from the given +/// timezone to UTC timezone. +/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkToUtcTimestamp { + signature: Signature, +} + +impl Default for SparkToUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkToUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkToUtcTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(to_utc_timestamp, vec![])(&args.args) + } +} + +fn to_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("to_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + 
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`to_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`to_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`to_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_utc_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} + +fn adjust_to_utc_time(ts: i64, tz: Tz) -> Result { + let dt = match T::UNIT { + TimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), + TimeUnit::Microsecond => 
DateTime::from_timestamp_micros(ts), + TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), + TimeUnit::Second => DateTime::from_timestamp(ts, 0), + } + .ok_or_else(|| internal_datafusion_err!("Invalid timestamp"))?; + let naive_dt = dt.naive_utc(); + + let offset_seconds = tz + .offset_from_utc_datetime(&naive_dt) + .fix() + .local_minus_utc() as i64; + + let offset_in_unit = match T::UNIT { + TimeUnit::Nanosecond => offset_seconds.checked_mul(1_000_000_000), + TimeUnit::Microsecond => offset_seconds.checked_mul(1_000_000), + TimeUnit::Millisecond => offset_seconds.checked_mul(1_000), + TimeUnit::Second => Some(offset_seconds), + } + .ok_or_else(|| internal_datafusion_err!("Offset overflow"))?; + + ts.checked_sub(offset_in_unit).ok_or_else(|| { + internal_datafusion_err!("Timestamp overflow during timezone adjustment") + }) +} diff --git a/datafusion/spark/src/function/datetime/trunc.rs b/datafusion/spark/src/function/datetime/trunc.rs new file mode 100644 index 0000000000000..b584cc9a70d44 --- /dev/null +++ b/datafusion/spark/src/function/datetime/trunc.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark trunc supports date inputs only and extra format aliases. +/// Also spark trunc's argument order is (date, format). +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTrunc { + signature: Signature, +} + +impl Default for SparkTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Date, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTrunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark trunc should have been simplified to standard date_trunc") + } + + fn simplify( + 
&self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [dt_expr, fmt_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "Second argument of `TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "year" | "month" | "day" | "week" | "quarter" => fmt.as_str(), + _ => { + return plan_err!( + "The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter." + ); + } + }; + let return_type = dt_expr.get_type(info.schema())?; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + // Spark uses Dates so we need to cast to timestamp and back to work with datafusion's date_trunc + Ok(ExprSimplifyResult::Simplified( + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![ + fmt_expr, + dt_expr.cast_to( + &DataType::Timestamp(TimeUnit::Nanosecond, None), + info.schema(), + )?, + ], + )) + .cast_to(&return_type, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/unix.rs b/datafusion/spark/src/function/datetime/unix.rs new file mode 100644 index 0000000000000..4254b2ed85d58 --- /dev/null +++ b/datafusion/spark/src/function/datetime/unix.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::logical_date; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Returns the number of days since epoch (1970-01-01) for the given date. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixDate { + signature: Signature, +} + +impl Default for SparkUnixDate { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnixDate { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_date(), + ))], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkUnixDate { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unix_date" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + 
internal_err!("invoke_with_args should not be called on SparkUnixDate") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + date.cast_to(&DataType::Date32, info.schema())? + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixTimestamp { + time_unit: TimeUnit, + signature: Signature, + name: &'static str, +} + +impl SparkUnixTimestamp { + pub fn new(name: &'static str, time_unit: TimeUnit) -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Timestamp)], + Volatility::Immutable, + ), + time_unit, + name, + } + } + + /// Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn microseconds() -> Self { + Self::new("unix_micros", TimeUnit::Microsecond) + } + + /// Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn milliseconds() -> Self { + Self::new("unix_millis", TimeUnit::Millisecond) + } + + /// Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. 
+ /// + pub fn seconds() -> Self { + Self::new("unix_seconds", TimeUnit::Second) + } +} + +impl ScalarUDFImpl for SparkUnixTimestamp { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int64, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke_with_args should not be called on `{}`", self.name()) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [ts] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + ts.cast_to( + &DataType::Timestamp(self.time_unit, Some("UTC".into())), + info.schema(), + )? + .cast_to(&DataType::Int64, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs index 1f17275062778..2f01854d37324 100644 --- a/datafusion/spark/src/function/hash/sha2.rs +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -15,26 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-extern crate datafusion_functions; - -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, +use arrow::array::{ + ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray, new_null_array, }; -use crate::function::math::hex::spark_sha2_hex; -use arrow::array::{ArrayRef, AsArray, StringArray}; use arrow::datatypes::{DataType, Int32Type}; -use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; -use datafusion_expr::Signature; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; -pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use datafusion_common::types::{ + NativeType, logical_binary, logical_int32, logical_string, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use sha2::{self, Digest}; use std::any::Any; use std::sync::Arc; +/// Differs from DataFusion version in allowing array input for bit lengths, and +/// also hex encoding the output. 
+/// /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkSha2 { signature: Signature, - aliases: Vec, } impl Default for SparkSha2 { @@ -46,8 +51,21 @@ impl Default for SparkSha2 { impl SparkSha2 { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ], + Volatility::Immutable, + ), } } } @@ -65,163 +83,191 @@ impl ScalarUDFImpl for SparkSha2 { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[1].is_null() { - return Ok(DataType::Null); - } - Ok(match arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary => DataType::Utf8, - DataType::Null => DataType::Null, - _ => { - return exec_err!( - "{} function can only accept strings or binary arrays.", - self.name() - ); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { - internal_datafusion_err!("Expected 2 arguments for function sha2") - })?; + let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?; - sha2(args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } + match (values, bit_lengths) { + ( + ColumnarValue::Scalar(value_scalar), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + if value_scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 
{ - return Err(invalid_arg_count_exec_err( - self.name(), - (2, 2), - arg_types.len(), - )); - } - let expr_type = match &arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary - | DataType::Null => Ok(arg_types[0].clone()), - _ => Err(unsupported_data_type_exec_err( - self.name(), - "String, Binary", - &arg_types[0], - )), - }?; - let bit_length_type = if arg_types[1].is_numeric() { - Ok(DataType::Int32) - } else if arg_types[1].is_null() { - Ok(DataType::Null) - } else { - Err(unsupported_data_type_exec_err( - self.name(), - "Numeric Type", - &arg_types[1], - )) - }?; - - Ok(vec![expr_type, bit_length_type]) - } -} + // Accept both Binary and Utf8 scalars (depending on coercion) + let bytes = match value_scalar { + ScalarValue::Binary(Some(b)) => b.as_slice(), + ScalarValue::LargeBinary(Some(b)) => b.as_slice(), + ScalarValue::BinaryView(Some(b)) => b.as_slice(), + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => s.as_bytes(), + other => { + return internal_err!( + "Unsupported scalar datatype for sha2: {}", + other.data_type() + ); + } + }; -pub fn sha2(args: [ColumnarValue; 2]) -> Result { - match args { - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2( - bit_length_arg, - &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], - ), - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))), - ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]), - [ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), - ColumnarValue::Array(bit_length_arg), - ] => { - let arr: StringArray = bit_length_arg - .as_primitive::() - .iter() - .map(|bit_length| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), - 
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + let out = match bit_length { + 224 => { + let mut digest = sha2::Sha224::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) - } - [ - ColumnarValue::Array(expr_arg), - ColumnarValue::Array(bit_length_arg), - ] => { - let expr_iter = expr_arg.as_string::().iter(); - let bit_length_iter = bit_length_arg.as_primitive::().iter(); - let arr: StringArray = expr_iter - .zip(bit_length_iter) - .map(|(expr, bit_length)| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - expr.unwrap().to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + 0 | 256 => { + let mut digest = sha2::Sha256::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + 384 => { + let mut digest = sha2::Sha384::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + 512 => { + let mut digest = sha2::Sha512::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + _ => None, + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out))) + } + // Array values + scalar bit length (common case: sha2(col, 256)) + ( + ColumnarValue::Array(values_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + let output: ArrayRef = match values_array.data_type() { + DataType::Binary => sha2_binary_scalar_bitlen( + 
&values_array.as_binary::(), + *bit_length, + ), + DataType::LargeBinary => sha2_binary_scalar_bitlen( + &values_array.as_binary::(), + *bit_length, + ), + DataType::BinaryView => sha2_binary_scalar_bitlen( + &values_array.as_binary_view(), + *bit_length, + ), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(ColumnarValue::Array(output)) + } + ( + ColumnarValue::Scalar(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + ( + ColumnarValue::Array(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Array(new_null_array( + &DataType::Utf8, + args.number_rows, + ))), + _ => { + // Fallback to existing behavior for any array/mixed cases + make_scalar_function(sha2_impl, vec![])(&args.args) + } } - _ => exec_err!("Unsupported argument types for sha2 function"), } } -fn compute_sha2( - bit_length_arg: i32, - expr_arg: &[ColumnarValue], -) -> Result { - match bit_length_arg { - 0 | 256 => sha256(expr_arg), - 224 => sha224(expr_arg), - 384 => sha384(expr_arg), - 512 => sha512(expr_arg), - _ => { - // Return null for unsupported bit lengths instead of error, because spark sha2 does not - // error out for this. 
- return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); +fn sha2_impl(args: &[ArrayRef]) -> Result { + let [values, bit_lengths] = take_function_args("sha2", args)?; + + let bit_lengths = bit_lengths.as_primitive::(); + let output = match values.data_type() { + DataType::Binary => sha2_binary_impl(&values.as_binary::(), bit_lengths), + DataType::LargeBinary => { + sha2_binary_impl(&values.as_binary::(), bit_lengths) } + DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(output) +} + +fn sha2_binary_impl<'a, BinaryArrType>( + values: &BinaryArrType, + bit_lengths: &Int32Array, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, bit_lengths.iter()) +} + +fn sha2_binary_scalar_bitlen<'a, BinaryArrType>( + values: &BinaryArrType, + bit_length: i32, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length))) +} + +fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>( + values: &BinaryArrType, + bit_lengths: I, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, + I: Iterator>, +{ + let array = values + .iter() + .zip(bit_lengths) + .map(|(value, bit_length)| match (value, bit_length) { + (Some(value), Some(224)) => { + let mut digest = sha2::Sha224::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(0 | 256)) => { + let mut digest = sha2::Sha256::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(384)) => { + let mut digest = sha2::Sha384::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(512)) => { + let mut digest = sha2::Sha512::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + // Unknown bit-lengths go to null, same as in Spark + _ => None, + }) + .collect::(); + 
Arc::new(array) +} + +const HEX_CHARS: [u8; 16] = *b"0123456789abcdef"; + +#[inline] +fn hex_encode>(data: T) -> String { + let bytes = data.as_ref(); + let mut out = Vec::with_capacity(bytes.len() * 2); + for &b in bytes { + let hi = b >> 4; + let lo = b & 0x0F; + out.push(HEX_CHARS[hi as usize]); + out.push(HEX_CHARS[lo as usize]); } - .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) + // SAFETY: out contains only ASCII + unsafe { String::from_utf8_unchecked(out) } } diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs index 2f596b19b422f..c9ebed6f612e1 100644 --- a/datafusion/spark/src/function/map/mod.rs +++ b/datafusion/spark/src/function/map/mod.rs @@ -17,6 +17,7 @@ pub mod map_from_arrays; pub mod map_from_entries; +pub mod str_to_map; mod utils; use datafusion_expr::ScalarUDF; @@ -25,6 +26,7 @@ use std::sync::Arc; make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays); make_udf_function!(map_from_entries::MapFromEntries, map_from_entries); +make_udf_function!(str_to_map::SparkStrToMap, str_to_map); pub mod expr_fn { use datafusion_functions::export_functions; @@ -40,8 +42,14 @@ pub mod expr_fn { "Creates a map from array>.", arg1 )); + + export_functions!(( + str_to_map, + "Creates a map after splitting the text into key/value pairs using delimiters.", + text pair_delim key_value_delim + )); } pub fn functions() -> Vec> { - vec![map_from_arrays(), map_from_entries()] + vec![map_from_arrays(), map_from_entries(), str_to_map()] } diff --git a/datafusion/spark/src/function/map/str_to_map.rs b/datafusion/spark/src/function/map/str_to_map.rs new file mode 100644 index 0000000000000..b722fb7abd6b2 --- /dev/null +++ b/datafusion/spark/src/function/map/str_to_map.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; + +use crate::function::map::utils::map_type_from_key_value_types; + +const DEFAULT_PAIR_DELIM: &str = ","; +const DEFAULT_KV_DELIM: &str = ":"; + +/// Spark-compatible `str_to_map` expression +/// +/// +/// Creates a map from a string by splitting on delimiters. +/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map +/// +/// - text: The input string +/// - pairDelim: Delimiter between key-value pairs (default: ',') +/// - keyValueDelim: Delimiter between key and value (default: ':') +/// +/// # Duplicate Key Handling +/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys. +/// See `spark.sql.mapKeyDedupPolicy`: +/// +/// +/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR. 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkStrToMap { + signature: Signature, +} + +impl Default for SparkStrToMap { + fn default() -> Self { + Self::new() + } +} + +impl SparkStrToMap { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // str_to_map(text) + TypeSignature::String(1), + // str_to_map(text, pairDelim) + TypeSignature::String(2), + // str_to_map(text, pairDelim, keyValueDelim) + TypeSignature::String(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkStrToMap { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "str_to_map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8); + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let arrays: Vec = ColumnarValue::values_to_arrays(&args.args)?; + let result = str_to_map_inner(&arrays)?; + Ok(ColumnarValue::Array(result)) + } +} + +fn str_to_map_inner(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => match args[0].data_type() { + DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None), + DataType::LargeUtf8 => { + str_to_map_impl(as_large_string_array(&args[0])?, None, None) + } + DataType::Utf8View => { + str_to_map_impl(as_string_view_array(&args[0])?, None, None) + } + other => exec_err!( + "Unsupported data type {other:?} for str_to_map, \ + expected Utf8, LargeUtf8, or Utf8View" + ), + }, + 2 => match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + 
Some(as_string_array(&args[1])?), + None, + ), + (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + None, + ), + (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + None, + ), + (t1, t2) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + 3 => match ( + args[0].data_type(), + args[1].data_type(), + args[2].data_type(), + ) { + (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + Some(as_string_array(&args[1])?), + Some(as_string_array(&args[2])?), + ), + (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { + str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + Some(as_large_string_array(&args[2])?), + ) + } + (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { + str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + Some(as_string_view_array(&args[2])?), + ) + } + (t1, t2, t3) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + n => exec_err!("str_to_map expects 1-3 arguments, got {n}"), + } +} + +fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( + text_array: V, + pair_delim_array: Option, + kv_delim_array: Option, +) -> Result { + let num_rows = text_array.len(); + + // Precompute combined null buffer from all input arrays. + // NullBuffer::union performs a bitmap-level AND, which is more efficient + // than checking per-row nullability inline. 
+ let text_nulls = text_array.nulls().cloned(); + let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned()); + let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned()); + let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()] + .into_iter() + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); + + // Use field names matching map_type_from_key_value_types: "key" and "value" + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut map_builder = MapBuilder::new( + Some(field_names), + StringBuilder::new(), + StringBuilder::new(), + ); + + let mut seen_keys = HashSet::new(); + for row_idx in 0..num_rows { + if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) { + map_builder.append(false)?; + continue; + } + + // Per-row delimiter extraction + let pair_delim = + pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx)); + let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx)); + + let text = text_array.value(row_idx); + if text.is_empty() { + // Empty string -> map with empty key and NULL value (Spark behavior) + map_builder.keys().append_value(""); + map_builder.values().append_null(); + map_builder.append(true)?; + continue; + } + + seen_keys.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } + + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); + + // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config + // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default) + if !seen_keys.insert(key) { + return exec_err!( + "Duplicate map key '{key}' was found, please check the input data. \ + If you want to remove the duplicated keys, you can set \ + spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \ + inserted at last takes precedence." 
+ ); + } + + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } + } + map_builder.append(true)?; + } + + Ok(Arc::new(map_builder.finish())) +} diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 101291ac5f66e..5edb40ae8ae9b 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -17,13 +17,15 @@ use arrow::array::*; use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::error::ArrowError; use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err}; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, make_wrapping_abs_function, + downcast_named_arg, make_abs_function, make_try_abs_function, + make_wrapping_abs_function, }; use std::any::Any; use std::sync::Arc; @@ -34,8 +36,10 @@ use std::sync::Arc; /// Returns the absolute value of input /// Returns NULL if input is NULL, returns NaN if input is NaN. /// -/// TODOs: +/// Differences with DataFusion abs: /// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow +/// +/// TODOs: /// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. /// #[derive(Debug, PartialEq, Eq, Hash)] @@ -85,19 +89,39 @@ impl ScalarUDFImpl for SparkAbs { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_abs(&args.args) + spark_abs(&args.args, args.config_options.execution.enable_ansi_mode) } } macro_rules! 
scalar_compute_op { - ($INPUT:ident, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some( result, )))) }}; - ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE( Some(result), $PRECISION, @@ -106,7 +130,10 @@ macro_rules! 
scalar_compute_op { }}; } -pub fn spark_abs(args: &[ColumnarValue]) -> Result { +pub fn spark_abs( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { if args.len() != 1 { return internal_err!("abs takes exactly 1 argument, but got: {}", args.len()); } @@ -119,19 +146,35 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { - let abs_fun = make_wrapping_abs_function!(Int8Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int8Array) + } else { + make_wrapping_abs_function!(Int8Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int16 => { - let abs_fun = make_wrapping_abs_function!(Int16Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int16Array) + } else { + make_wrapping_abs_function!(Int16Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int32 => { - let abs_fun = make_wrapping_abs_function!(Int32Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int32Array) + } else { + make_wrapping_abs_function!(Int32Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int64 => { - let abs_fun = make_wrapping_abs_function!(Int64Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int64Array) + } else { + make_wrapping_abs_function!(Int64Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Float32 => { @@ -143,11 +186,19 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - let abs_fun = make_wrapping_abs_function!(Decimal128Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal128Array) + } else { + make_wrapping_abs_function!(Decimal128Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Decimal256(_, _) => { - let abs_fun = make_wrapping_abs_function!(Decimal256Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal256Array) + } else { + make_wrapping_abs_function!(Decimal256Array) + }; abs_fun(array).map(ColumnarValue::Array) } 
dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), @@ -159,10 +210,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), sv if sv.is_null() => Ok(args[0].clone()), - ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8), - ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16), - ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32), - ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64), + ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8), + ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16), + ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32), + ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64), ScalarValue::Float32(Some(v)) => { Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))) } @@ -170,10 +221,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - scalar_compute_op!(v, *precision, *scale, Decimal128) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128) } ScalarValue::Decimal256(Some(v), precision, scale) => { - scalar_compute_op!(v, *precision, *scale, Decimal256) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256) } dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, @@ -185,100 +236,12 @@ mod tests { use super::*; use arrow::datatypes::i256; - macro_rules! 
eval_legacy_mode { - ($TYPE:ident, $VAL:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $VAL); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $RESULT); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $VAL); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $RESULT); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - } - - #[test] - fn test_abs_scalar_legacy_mode() { - // NumericType MIN - eval_legacy_mode!(UInt8, u8::MIN); - eval_legacy_mode!(UInt16, u16::MIN); - eval_legacy_mode!(UInt32, u32::MIN); - eval_legacy_mode!(UInt64, u64::MIN); - eval_legacy_mode!(Int8, i8::MIN); - eval_legacy_mode!(Int16, i16::MIN); - eval_legacy_mode!(Int32, i32::MIN); - eval_legacy_mode!(Int64, i64::MIN); - eval_legacy_mode!(Float32, f32::MIN, f32::MAX); - eval_legacy_mode!(Float64, f64::MIN, f64::MAX); - eval_legacy_mode!(Decimal128, i128::MIN, 18, 10); - eval_legacy_mode!(Decimal256, i256::MIN, 10, 2); - - // NumericType not 
MIN - eval_legacy_mode!(Int8, -1i8, 1i8); - eval_legacy_mode!(Int16, -1i16, 1i16); - eval_legacy_mode!(Int32, -1i32, 1i32); - eval_legacy_mode!(Int64, -1i64, 1i64); - eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128); - eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); - - // Float32, Float64 - eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, 0.0f32, 0.0f32); - eval_legacy_mode!(Float32, -0.0f32, 0.0f32); - eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, 0.0f64, 0.0f64); - eval_legacy_mode!(Float64, -0.0f64, 0.0f64); - } - macro_rules! eval_array_legacy_mode { ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); let expected = $OUTPUT; - match spark_abs(&[args]) { + match spark_abs(&[args], false) { Ok(ColumnarValue::Array(result)) => { let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); @@ -367,24 +330,187 @@ mod tests { ); eval_array_legacy_mode!( - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None]) .with_precision_and_scale(38, 37) .unwrap(), - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None]) .with_precision_and_scale(38, 37) .unwrap(), as_decimal128_array ); eval_array_legacy_mode!( - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::MINUS_ONE), + Some(i256::MIN + i256::from(1)), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::ONE), + Some(i256::MAX), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + 
as_decimal256_array + ); + } + + macro_rules! eval_array_ansi_mode { + ($INPUT:expr) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + match spark_abs(&[args], true) { + Err(e) => { + assert!( + e.to_string().contains("overflow on abs"), + "Error message did not match. Actual message: {e}" + ); + } + _ => unreachable!(), + } + }}; + ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = $OUTPUT; + match spark_abs(&[args], true) { + Ok(ColumnarValue::Array(result)) => { + let actual = datafusion_common::cast::$FUNC(&result).unwrap(); + assert_eq!(actual, &expected); + } + _ => unreachable!(), + } + }}; + } + #[test] + fn test_abs_array_ansi_mode() { + eval_array_ansi_mode!( + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + as_uint64_array + ); + + eval_array_ansi_mode!(Int8Array::from(vec![ + Some(-1), + Some(i8::MIN), + Some(i8::MAX), + None + ])); + eval_array_ansi_mode!(Int16Array::from(vec![ + Some(-1), + Some(i16::MIN), + Some(i16::MAX), + None + ])); + eval_array_ansi_mode!(Int32Array::from(vec![ + Some(-1), + Some(i32::MIN), + Some(i32::MAX), + None + ])); + eval_array_ansi_mode!(Int64Array::from(vec![ + Some(-1), + Some(i64::MIN), + Some(i64::MAX), + None + ])); + eval_array_ansi_mode!( + Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float32_array + ); + + eval_array_ansi_mode!( + Float64Array::from(vec![ + Some(-1f64), + Some(f64::MIN), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + 
Some(-0.0), + ]), + Float64Array::from(vec![ + Some(1f64), + Some(f64::MAX), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float64_array + ); + + // decimal: no arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)]) + .with_precision_and_scale(38, 37) .unwrap(), - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)]) + .with_precision_and_scale(38, 37) .unwrap(), + as_decimal128_array + ); + + eval_array_ansi_mode!( + Decimal256Array::from(vec![ + Some(i256::MINUS_ONE), + Some(i256::from(-2)), + Some(i256::MIN + i256::from(1)) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::ONE), + Some(i256::from(2)), + Some(i256::MAX) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), as_decimal256_array ); + + // decimal: arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap() + ); + eval_array_ansi_mode!( + Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap() + ); } #[test] diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index ef62b08fb03d2..06c77f37021bf 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -16,9 +16,10 @@ // under the License. 
use std::any::Any; +use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -91,11 +92,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -110,37 +113,85 @@ impl ScalarUDFImpl for SparkHex { } } -fn hex_int64(num: i64) -> String { - format!("{num:X}") -} - /// Hex encoding lookup tables for fast byte-to-hex conversion const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; #[inline] -fn hex_encode>(data: T, lower_case: bool) -> String { - let bytes = data.as_ref(); - let mut s = String::with_capacity(bytes.len() * 2); - let hex_chars = if lower_case { +fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { + if num == 0 { + return b"0"; + } + + let mut n = num as u64; + let mut i = 16; + while n != 0 { + i -= 1; + buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize]; + n >>= 4; + } + &buffer[i..] 
+} + +/// Generic hex encoding for byte array types +fn hex_encode_bytes<'a, I, T>( + iter: I, + lowercase: bool, + len: usize, +) -> Result +where + I: Iterator>, + T: AsRef<[u8]> + 'a, +{ + let mut builder = StringBuilder::with_capacity(len, len * 64); + let mut buffer = Vec::with_capacity(64); + let hex_chars = if lowercase { HEX_CHARS_LOWER } else { HEX_CHARS_UPPER }; - for &b in bytes { - s.push(hex_chars[(b >> 4) as usize] as char); - s.push(hex_chars[(b & 0x0f) as usize] as char); + + for v in iter { + if let Some(b) = v { + buffer.clear(); + let bytes = b.as_ref(); + for &byte in bytes { + buffer.push(hex_chars[(byte >> 4) as usize]); + buffer.push(hex_chars[(byte & 0x0f) as usize]); + } + // SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(&buffer)); + } + } else { + builder.append_null(); + } } - s + + Ok(Arc::new(builder.finish())) } -#[inline(always)] -fn hex_bytes>( - bytes: T, - lowercase: bool, -) -> Result { - let hex_string = hex_encode(bytes, lowercase); - Ok(hex_string) +/// Generic hex encoding for int64 type +fn hex_encode_int64( + iter: impl Iterator>, + len: usize, +) -> Result { + let mut builder = StringBuilder::with_capacity(len, len * 16); + + for v in iter { + if let Some(num) = v { + let mut temp = [0u8; 16]; + let slice = hex_int64(num, &mut temp); + // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(slice)); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) } /// Spark-compatible `hex` function @@ -166,103 +217,109 @@ pub fn compute_hex( ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - - let hexed_array: StringArray = - array.iter().map(|v| v.map(hex_int64)).collect(); - - Ok(ColumnarValue::Array(Arc::new(hexed_array))) + Ok(ColumnarValue::Array(hex_encode_int64( + array.iter(), + 
array.len(), + )?)) } DataType::Utf8 => { let array = as_string_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Utf8View => { let array = as_string_view_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Binary => { let array = as_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeBinary => { let array = as_large_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + 
array.len(), + )?)) } - DataType::Dictionary(_, value_type) => { - let dict = as_dictionary_array::(&array); + DataType::Dictionary(key_type, _) => { + if **key_type != DataType::Int32 { + return exec_err!( + "hex only supports Int32 dictionary keys, get: {}", + key_type + ); + } - let values = match **value_type { - DataType::Int64 => as_int64_array(dict.values())? - .iter() - .map(|v| v.map(hex_int64)) - .collect::>(), - DataType::Utf8 => as_string_array(dict.values()) - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - DataType::Binary => as_binary_array(dict.values())? - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - _ => exec_err!( - "hex got an unexpected argument type: {}", - array.data_type() - )?, + let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); + + let encoded_values = match dict_values.data_type() { + DataType::Int64 => { + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? + } + DataType::Utf8 => { + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeUtf8 => { + let arr = as_largestring_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Utf8View => { + let arr = as_string_view_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Binary => { + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeBinary => { + let arr = as_large_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::FixedSizeBinary(_) => { + let arr = as_fixed_size_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? 
+ } + _ => { + return exec_err!( + "hex got an unexpected argument type: {}", + dict_values.data_type() + ); + } }; - let new_values: Vec> = dict - .keys() - .iter() - .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) - .collect(); - - let string_array_values = StringArray::from(new_values); - - Ok(ColumnarValue::Array(Arc::new(string_array_values))) + let new_dict = dict.with_values(encoded_values); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -272,16 +329,18 @@ pub fn compute_hex( #[cfg(test)] mod test { + use std::str::from_utf8_unchecked; use std::sync::Arc; - use arrow::array::{Int64Array, StringArray}; + use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, - StringDictionaryBuilder, as_string_array, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -293,12 +352,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -308,7 +367,7 @@ mod test { _ => panic!("Expected array"), }; - let result = 
as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -322,12 +381,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -337,7 +396,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -351,7 +410,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -366,20 +425,24 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } #[test] fn test_hex_int64() { - let num = 1234; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "4D2".to_string()); + let test_cases = vec![(1234, "4D2"), (-1, "FFFFFFFFFFFFFFFF")]; + + for (num, expected) in test_cases { + let mut cache = [0u8; 16]; + let slice = super::hex_int64(num, &mut cache); - let num = -1; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + unsafe { + let result = from_utf8_unchecked(slice); + 
assert_eq!(expected, result); + } + } } #[test] @@ -403,4 +466,28 @@ mod test { assert_eq!(string_array, &expected_array); } + + #[test] + fn test_dict_values_null() { + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = Int64Array::from(vec![Some(32), None]); + // [32, null, null] + let dict = DictionaryArray::new(keys, Arc::new(vals)); + + let columnar_value = ColumnarValue::Array(Arc::new(dict)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); + + assert_eq!(&expected, result); + } } diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 1422eb250d939..92d8e90ac372e 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -20,8 +20,10 @@ pub mod expm1; pub mod factorial; pub mod hex; pub mod modulus; +pub mod negative; pub mod rint; pub mod trigonometry; +pub mod unhex; pub mod width_bucket; use datafusion_expr::ScalarUDF; @@ -35,9 +37,11 @@ make_udf_function!(hex::SparkHex, hex); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); make_udf_function!(trigonometry::SparkSec, sec); +make_udf_function!(negative::SparkNegative, negative); pub mod expr_fn { use datafusion_functions::export_functions; @@ -57,9 +61,15 @@ pub mod expr_fn { "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1 )); + 
export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); export_functions!((sec, "Returns the secant of expr.", arg1)); + export_functions!(( + negative, + "Returns the negation of expr (unary minus).", + arg1 + )); } pub fn functions() -> Vec> { @@ -71,8 +81,10 @@ pub fn functions() -> Vec> { modulus(), pmod(), rint(), + unhex(), width_bucket(), csc(), sec(), + negative(), ] } diff --git a/datafusion/spark/src/function/math/negative.rs b/datafusion/spark/src/function/math/negative.rs new file mode 100644 index 0000000000000..2df71b709d8c4 --- /dev/null +++ b/datafusion/spark/src/function/math/negative.rs @@ -0,0 +1,477 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::types::*; +use arrow::array::*; +use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit}; +use bigdecimal::num_traits::WrappingNeg; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `negative` expression +/// +/// +/// Returns the negation of input (equivalent to unary minus) +/// Returns NULL if input is NULL, returns NaN if input is NaN. +/// +/// ANSI mode support: +/// - When ANSI mode is disabled (`spark.sql.ansi.enabled=false`), negating the minimal +/// value of a signed integer wraps around. For example: negative(i32::MIN) returns +/// i32::MIN (wraps instead of error). +/// - When ANSI mode is enabled (`spark.sql.ansi.enabled=true`), overflow conditions +/// throw an ARITHMETIC_OVERFLOW error instead of wrapping. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkNegative { + signature: Signature, +} + +impl Default for SparkNegative { + fn default() -> Self { + Self::new() + } +} + +impl SparkNegative { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + // Numeric types: signed integers, float, decimals + TypeSignature::Numeric(1), + // Interval types: YearMonth, DayTime, MonthDayNano + TypeSignature::Uniform( + 1, + vec![ + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), + ], + ), + ]), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkNegative { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "negative" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_negative(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Macro to implement negation for integer array types +macro_rules! impl_integer_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array.try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for float array types +macro_rules! 
impl_float_array_negative { + ($array:expr, $type:ty) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = array.unary(|x| -x); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for decimal array types +macro_rules! impl_decimal_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array + .try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + .with_data_type(array.data_type().clone()) + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for integer scalar types +macro_rules! impl_integer_scalar_negative { + ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? + } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(result)))) + }}; +} + +/// Macro to implement negation for decimal scalar types +macro_rules! impl_decimal_scalar_negative { + ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? 
+ } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant( + Some(result), + *$precision, + *$scale, + ))) + }}; +} + +/// Core implementation of Spark's negative function +fn spark_negative( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { + let [arg] = take_function_args("negative", args)?; + + match arg { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Int8 => { + impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode) + } + + // Floating point - simple negation (no overflow possible) + DataType::Float16 => impl_float_array_negative!(array, Float16Type), + DataType::Float32 => impl_float_array_negative!(array, Float32Type), + DataType::Float64 => impl_float_array_negative!(array, Float64Type), + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Decimal32(_, _) => impl_decimal_array_negative!( + array, + Decimal32Type, + "Decimal32", + enable_ansi_mode + ), + DataType::Decimal64(_, _) => impl_decimal_array_negative!( + array, + Decimal64Type, + "Decimal64", + enable_ansi_mode + ), + DataType::Decimal128(_, _) => impl_decimal_array_negative!( + array, + Decimal128Type, + "Decimal128", + enable_ansi_mode + ), + DataType::Decimal256(_, _) => impl_decimal_array_negative!( + array, + Decimal256Type, + "Decimal256", + enable_ansi_mode + ), + + // interval type - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Interval(IntervalUnit::YearMonth) => { + impl_integer_array_negative!( + array, + 
IntervalYearMonthType, + "IntervalYearMonth", + enable_ansi_mode + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode { + array.try_unary(|x| { + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = + x.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + x.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalDayTime { + days, + milliseconds, + }) + })? + } else { + array.unary(|x| IntervalDayTime { + days: x.days.wrapping_neg(), + milliseconds: x.milliseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode + { + array.try_unary(|x| { + let months = x.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + x.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + x.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano { + months, + days, + nanoseconds, + }) + })? 
+ } else { + array.unary(|x| IntervalMonthDayNano { + months: x.months.wrapping_neg(), + days: x.days.wrapping_neg(), + nanoseconds: x.nanoseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(arg.clone()), + _ if sv.is_null() => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Int8(Some(v)) => { + impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode) + } + ScalarValue::Int16(Some(v)) => { + impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode) + } + ScalarValue::Int32(Some(v)) => { + impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode) + } + ScalarValue::Int64(Some(v)) => { + impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode) + } + + // Floating point - simple negation + ScalarValue::Float16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v)))) + } + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v)))) + } + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Decimal32(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal32", + Decimal32, + enable_ansi_mode + ) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal64", + Decimal64, + enable_ansi_mode + ) + } + ScalarValue::Decimal128(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal128", + Decimal128, + enable_ansi_mode + ) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + 
scale, + "Decimal256", + Decimal256, + enable_ansi_mode + ) + } + + //interval type - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::IntervalYearMonth(Some(v)) => { + impl_integer_scalar_negative!( + v, + "IntervalYearMonth", + IntervalYearMonth, + enable_ansi_mode + ) + } + ScalarValue::IntervalDayTime(Some(v)) => { + let result = if enable_ansi_mode { + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + v.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalDayTime { days, milliseconds } + } else { + IntervalDayTime { + days: v.days.wrapping_neg(), + milliseconds: v.milliseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + result, + )))) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + let result = if enable_ansi_mode { + let months = v.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + v.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + v.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalMonthDayNano { + months, + days, + nanoseconds, + } + } else { + IntervalMonthDayNano { + months: v.months.wrapping_neg(), + days: v.days.wrapping_neg(), + nanoseconds: v.nanoseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano( + Some(result), + ))) + } + + dt => not_impl_err!("Not supported 
datatype for Spark negative(): {dt}"), + }, + } +} diff --git a/datafusion/spark/src/function/math/unhex.rs b/datafusion/spark/src/function/math/unhex.rs new file mode 100644 index 0000000000000..dee532d818f83 --- /dev/null +++ b/datafusion/spark/src/function/math/unhex.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Array, ArrayRef, BinaryBuilder}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::types::logical_string; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnhex { + signature: Signature, +} + +impl Default for SparkUnhex { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnhex { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + Self { + signature: Signature::coercible(vec![string], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkUnhex { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unhex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Binary) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_unhex(&args.args) + } +} + +#[inline] +fn hex_nibble(c: u8) -> Option { + match c { + b'0'..=b'9' => Some(c - b'0'), + b'a'..=b'f' => Some(c - b'a' + 10), + b'A'..=b'F' => Some(c - b'A' + 10), + _ => None, + } +} + +/// Decodes a hex-encoded byte slice into binary data. +/// Returns `true` if decoding succeeded, `false` if the input contains invalid hex characters. +fn unhex_common(bytes: &[u8], out: &mut Vec) -> bool { + if bytes.is_empty() { + return true; + } + + let mut i = 0usize; + + // If the hex string length is odd, implicitly left-pad with '0'. 
+ if (bytes.len() & 1) == 1 { + match hex_nibble(bytes[0]) { + // Equivalent to (0 << 4) | lo + Some(lo) => out.push(lo), + None => return false, + } + i = 1; + } + + while i + 1 < bytes.len() { + match (hex_nibble(bytes[i]), hex_nibble(bytes[i + 1])) { + (Some(hi), Some(lo)) => out.push((hi << 4) | lo), + _ => return false, + } + i += 2; + } + + true +} + +/// Converts an iterator of hex strings to a binary array. +fn unhex_array( + iter: I, + len: usize, + capacity: usize, +) -> Result +where + I: Iterator>, + T: AsRef, +{ + let mut builder = BinaryBuilder::with_capacity(len, capacity); + let mut buffer = Vec::new(); + + for v in iter { + if let Some(s) = v { + buffer.clear(); + buffer.reserve(s.as_ref().len().div_ceil(2)); + if unhex_common(s.as_ref().as_bytes(), &mut buffer) { + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Convert a single hex string to binary +fn unhex_scalar(s: &str) -> Option> { + let mut buffer = Vec::with_capacity(s.len().div_ceil(2)); + if unhex_common(s.as_bytes(), &mut buffer) { + Some(buffer) + } else { + None + } +} + +fn spark_unhex(args: &[ColumnarValue]) -> Result { + let [args] = take_function_args("unhex", args)?; + + match args { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let array = as_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::Utf8View => { + let array = as_string_view_array(array)?; + // Estimate capacity since StringViewArray data can be scattered or inlined. 
+ let capacity = array.len() * 32; + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::LargeUtf8 => { + let array = as_large_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + _ => exec_err!( + "unhex only supports string argument, but got: {}", + array.data_type() + ), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + ScalarValue::Utf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(unhex_scalar(s)))) + } + _ => { + exec_err!( + "unhex only supports string argument, but got: {}", + sv.data_type() + ) + } + }, + } +} diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs index 8d748439ad806..bd68c37edb517 100644 --- a/datafusion/spark/src/function/math/width_bucket.rs +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -26,11 +26,11 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval}; use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth}; use datafusion_common::cast::{ - as_duration_microsecond_array, as_float64_array, as_int32_array, + as_duration_microsecond_array, as_float64_array, as_int64_array, as_interval_mdn_array, as_interval_ym_array, }; use datafusion_common::types::{ - NativeType, logical_duration_microsecond, logical_float64, logical_int32, + NativeType, logical_duration_microsecond, logical_float64, logical_int64, logical_interval_mdn, logical_interval_year_month, }; use datafusion_common::{Result, exec_err, internal_err}; @@ -41,7 +41,7 @@ use datafusion_expr::{ }; use datafusion_functions::utils::make_scalar_function; 
-use arrow::array::{Int32Array, Int32Builder}; +use arrow::array::{Int32Array, Int32Builder, Int64Array}; use arrow::datatypes::TimeUnit::Microsecond; use datafusion_expr::Coercion; use datafusion_expr::Volatility::Immutable; @@ -75,9 +75,9 @@ impl SparkWidthBucket { let interval_mdn = Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn())); let bucket = Coercion::new_implicit( - TypeSignatureClass::Native(logical_int32()), + TypeSignatureClass::Native(logical_int64()), vec![TypeSignatureClass::Integer], - NativeType::Int32, + NativeType::Int64, ); let type_signature = Signature::one_of( vec![ @@ -160,28 +160,28 @@ fn width_bucket_kern(args: &[ArrayRef]) -> Result { let v = as_float64_array(v)?; let min = as_float64_array(minv)?; let max = as_float64_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket))) } Duration(Microsecond) => { let v = as_duration_microsecond_array(v)?; let min = as_duration_microsecond_array(minv)?; let max = as_duration_microsecond_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket))) } Interval(YearMonth) => { let v = as_interval_ym_array(v)?; let min = as_interval_ym_array(minv)?; let max = as_interval_ym_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket))) } Interval(MonthDayNano) => { let v = as_interval_mdn_array(v)?; let min = as_interval_mdn_array(minv)?; let max = as_interval_mdn_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_interval_mdn_exact( v, min, max, n_bucket, ))) @@ -203,7 +203,7 @@ macro_rules! 
width_bucket_kernel_impl { v: &$arr_ty, min: &$arr_ty, max: &$arr_ty, - n_bucket: &Int32Array, + n_bucket: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -223,6 +223,7 @@ macro_rules! width_bucket_kernel_impl { b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; if $check_nan { if !x.is_finite() || !l.is_finite() || !h.is_finite() { b.append_null(); @@ -249,7 +250,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x >= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -258,7 +259,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x <= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -272,8 +273,8 @@ macro_rules! width_bucket_kernel_impl { if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); @@ -309,7 +310,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( v: &IntervalMonthDayNanoArray, lo: &IntervalMonthDayNanoArray, hi: &IntervalMonthDayNanoArray, - n: &Int32Array, + n: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -324,6 +325,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; let x = v.value(i); let l = lo.value(i); @@ -349,7 +351,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m >= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -358,7 +360,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m <= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -373,8 +375,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } 
b.append_value(bucket); continue; @@ -400,7 +402,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f >= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -409,7 +411,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f <= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -424,8 +426,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); continue; @@ -443,15 +445,15 @@ mod tests { use std::sync::Arc; use arrow::array::{ - ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, + ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, Int64Array, IntervalYearMonthArray, }; use arrow::datatypes::IntervalMonthDayNano; // --- Helpers ------------------------------------------------------------- - fn i32_array_all(len: usize, val: i32) -> Arc { - Arc::new(Int32Array::from(vec![val; len])) + fn i64_array_all(len: usize, val: i64) -> Arc { + Arc::new(Int64Array::from(vec![val; len])) } fn f64_array(vals: &[f64]) -> Arc { @@ -489,7 +491,7 @@ mod tests { let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]); let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -501,7 +503,7 @@ mod tests { let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]); let lo = f64_array(&[10.0; 5]); let hi = f64_array(&[0.0; 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -513,7 +515,7 @@ mod tests { let v = f64_array(&[0.0, 9.999999999, 10.0]); let lo = f64_array(&[0.0; 3]); let hi = f64_array(&[10.0; 3]); - let n = 
i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -525,7 +527,7 @@ mod tests { let v = f64_array(&[10.0, 0.0, -0.000001]); let lo = f64_array(&[10.0; 3]); let hi = f64_array(&[0.0; 3]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -537,7 +539,7 @@ mod tests { let v = f64_array(&[1.0, 5.0, 9.0]); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = Arc::new(Int32Array::from(vec![0, -1, 10])); + let n = Arc::new(Int64Array::from(vec![0, -1, 10])); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -547,7 +549,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array(&[5.0]); let hi = f64_array(&[5.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -555,7 +557,7 @@ mod tests { let v = f64_array_opt(&[Some(f64::NAN)]); let lo = f64_array(&[0.0]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -566,7 +568,7 @@ mod tests { let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]); let lo = f64_array(&[0.0; 4]); let hi = f64_array(&[10.0; 4]); - let n = i32_array_all(4, 10); + let n = i64_array_all(4, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -578,7 +580,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array_opt(&[None]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ 
-591,7 +593,7 @@ mod tests { let v = dur_us_array(&[1_000_000, 0, -1]); let lo = dur_us_array(&[0, 0, 0]); let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]); - let n = i32_array_all(3, 2); + let n = i64_array_all(3, 2); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -603,7 +605,7 @@ mod tests { let v = dur_us_array(&[0]); let lo = dur_us_array(&[1]); let hi = dur_us_array(&[1]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); } @@ -615,7 +617,7 @@ mod tests { let v = ym_array(&[0, 5, 11, 12, 13]); let lo = ym_array(&[0; 5]); let hi = ym_array(&[12; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -627,7 +629,7 @@ mod tests { let v = ym_array(&[11, 12, 0, -1, 13]); let lo = ym_array(&[12; 5]); let hi = ym_array(&[0; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -641,7 +643,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(0, 0, 0); 5]); let hi = mdn_array(&[(12, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -653,7 +655,7 @@ mod tests { let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(12, 0, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -673,7 +675,7 @@ mod tests { ]); let lo = mdn_array(&[(0, 0, 0); 6]); let hi = mdn_array(&[(0, 10, 0); 6]); - let n = i32_array_all(6, 10); + let n = i64_array_all(6, 10); let 
out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -686,7 +688,7 @@ mod tests { let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -698,7 +700,7 @@ mod tests { let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -711,7 +713,7 @@ mod tests { let v = mdn_array(&[(0, 1, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(1, 1, 0)]); - let n = i32_array_all(1, 4); + let n = i64_array_all(1, 4); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -723,7 +725,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(1, 2, 3)]); let hi = mdn_array(&[(1, 2, 3)]); // lo == hi - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -734,7 +736,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0)]); - let n = Arc::new(Int32Array::from(vec![0])); // n <= 0 + let n = Arc::new(Int64Array::from(vec![0])); // n <= 0 let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -748,7 +750,7 @@ mod tests { ])); let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]); - let n = i32_array_all(2, 10); + let n = i64_array_all(2, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -773,7 +775,7 @@ mod tests { let v: 
ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err(); let msg = format!("{err}"); diff --git a/datafusion/spark/src/function/string/base64.rs b/datafusion/spark/src/function/string/base64.rs new file mode 100644 index 0000000000000..a171d4823b0fa --- /dev/null +++ b/datafusion/spark/src/function/string/base64.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{Coercion, Expr, ReturnFieldArgs, TypeSignatureClass, lit}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::expr_fn::{decode, encode}; + +/// Apache Spark base64 uses padded base64 encoding. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBase64 { + signature: Signature, +} + +impl Default for SparkBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBase64 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "base64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [bin] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match bin.data_type() { + DataType::LargeBinary => DataType::LargeUtf8, + _ => DataType::Utf8, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + bin.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!( + "invoke should not be called on a simplified {} function", + self.name() + ) + } + + fn simplify( + &self, + 
args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(encode( + bin, + lit("base64pad"), + ))) + } +} + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnBase64 { + signature: Signature, +} + +impl Default for SparkUnBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkUnBase64 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "unbase64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [str] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match str.data_type() { + DataType::LargeBinary => DataType::LargeBinary, + _ => DataType::Binary, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + str.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!("{} should have been simplified", self.name()) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(decode( + bin, + lit("base64pad"), + ))) + } +} diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index f3dae22866c23..b2073690fc446 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -20,8 +20,7 @@ use 
datafusion_common::arrow::datatypes::FieldRef; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ReturnFieldArgs; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::string::concat::ConcatFunc; use std::any::Any; @@ -54,10 +53,7 @@ impl Default for SparkConcat { impl SparkConcat { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -89,10 +85,21 @@ impl ScalarUDFImpl for SparkConcat { ) } fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + use DataType::*; + // Spark semantics: concat returns NULL if ANY input is NULL let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - Ok(Arc::new(Field::new("concat", DataType::Utf8, nullable))) + // Determine return type: Utf8View > LargeUtf8 > Utf8 + let mut dt = &Utf8; + for field in args.arg_fields { + let data_type = field.data_type(); + if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) { + dt = data_type; + } + } + + Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) } } @@ -110,9 +117,18 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // Handle zero-argument case: return empty string if arg_values.is_empty() { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - Some(String::new()), - ))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::new(), + )))), + DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( + Some(String::new()), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(String::new()), + ))), + }; } // Step 1: Check for NULL mask in incoming args @@ 
-120,7 +136,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))), + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } // Step 2: Delegate to DataFusion's concat @@ -181,6 +204,7 @@ mod tests { ); Ok(()) } + #[test] fn test_spark_concat_return_field_non_nullable() -> Result<()> { let func = SparkConcat::new(); diff --git a/datafusion/spark/src/function/string/format_string.rs b/datafusion/spark/src/function/string/format_string.rs index 73de985109b7c..8ab87196fdc61 100644 --- a/datafusion/spark/src/function/string/format_string.rs +++ b/datafusion/spark/src/function/string/format_string.rs @@ -1431,7 +1431,7 @@ impl ConversionSpecifier { let value = "null".to_string(); self.format_string(string, &value) } - _ => exec_err!("Invalid scalar value: {:?}", value), + _ => exec_err!("Invalid scalar value: {value}"), } } diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index 369d381a9c35b..8859beca77996 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -16,6 +16,7 @@ // under the License. 
pub mod ascii; +pub mod base64; pub mod char; pub mod concat; pub mod elt; @@ -25,12 +26,14 @@ pub mod length; pub mod like; pub mod luhn_check; pub mod space; +pub mod substring; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(base64::SparkBase64, base64); make_udf_function!(char::CharFunc, char); make_udf_function!(concat::SparkConcat, concat); make_udf_function!(ilike::SparkILike, ilike); @@ -40,6 +43,8 @@ make_udf_function!(like::SparkLike, like); make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); make_udf_function!(format_string::FormatStringFunc, format_string); make_udf_function!(space::SparkSpace, space); +make_udf_function!(substring::SparkSubstring, substring); +make_udf_function!(base64::SparkUnBase64, unbase64); pub mod expr_fn { use datafusion_functions::export_functions; @@ -49,6 +54,11 @@ pub mod expr_fn { "Returns the ASCII code point of the first character of string.", arg1 )); + export_functions!(( + base64, + "Encodes the input binary `bin` into a base64 string.", + bin + )); export_functions!(( char, "Returns the ASCII character having the binary equivalent to col. 
If col is larger than 256 the result is equivalent to char(col % 256).", @@ -90,11 +100,22 @@ pub mod expr_fn { strfmt args )); export_functions!((space, "Returns a string consisting of n spaces.", arg1)); + export_functions!(( + substring, + "Returns the substring from string `str` starting at position `pos` with length `length.", + str pos length + )); + export_functions!(( + unbase64, + "Decodes the input string `str` from a base64 string into binary data.", + str + )); } pub fn functions() -> Vec> { vec![ ascii(), + base64(), char(), concat(), elt(), @@ -104,5 +125,7 @@ pub fn functions() -> Vec> { luhn_check(), format_string(), space(), + substring(), + unbase64(), ] } diff --git a/datafusion/spark/src/function/string/substring.rs b/datafusion/spark/src/function/string/substring.rs new file mode 100644 index 0000000000000..524262b12f193 --- /dev/null +++ b/datafusion/spark/src/function/string/substring.rs @@ -0,0 +1,258 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringArrayType, StringViewBuilder, +}; +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{ + NativeType, logical_int32, logical_int64, logical_string, +}; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `substring` expression +/// +/// +/// Returns the substring from string starting at position pos with length len. +/// Position is 1-indexed. If pos is negative, it counts from the end of the string. +/// Returns NULL if any input is NULL. 
+#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSubstring { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSubstring { + fn default() -> Self { + Self::new() + } +} + +impl SparkSubstring { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let int64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Native(logical_int32())], + NativeType::Int64, + ); + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![string.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + string.clone(), + int64.clone(), + int64.clone(), + ]), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + "pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), + aliases: vec![String::from("substr")], + } + } +} + +impl ScalarUDFImpl for SparkSubstring { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substring" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_substring, vec![])(&args.args) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + datafusion_common::internal_err!( + "return_type should not be called for Spark substring" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + // Spark semantics: substring returns NULL if ANY input is NULL + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + "substring", + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } +} + +fn spark_substring(args: &[ArrayRef]) -> Result { + let start_array = as_int64_array(&args[1])?; + let length_array = if args.len() > 2 { + Some(as_int64_array(&args[2])?) 
+ } else { + None + }; + + match args[0].data_type() { + DataType::Utf8 => spark_substring_impl( + &args[0].as_string::(), + start_array, + length_array, + GenericStringBuilder::::new(), + ), + DataType::LargeUtf8 => spark_substring_impl( + &args[0].as_string::(), + start_array, + length_array, + GenericStringBuilder::::new(), + ), + DataType::Utf8View => spark_substring_impl( + &args[0].as_string_view(), + start_array, + length_array, + StringViewBuilder::new(), + ), + other => exec_err!( + "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8." + ), + } +} + +/// Convert Spark's start position to DataFusion's 1-based start position. +/// +/// Spark semantics: +/// - Positive start: 1-based index from beginning +/// - Zero start: treated as 1 +/// - Negative start: counts from end of string +/// +/// Returns the converted 1-based start position for use with `get_true_start_end`. +#[inline] +fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 { + if start >= 0 { + start.max(1) + } else { + let len_i64 = i64::try_from(len).unwrap_or(i64::MAX); + let start = start.saturating_add(len_i64).saturating_add(1); + start.max(1) + } +} + +trait StringArrayBuilder: ArrayBuilder { + fn append_value(&mut self, val: &str); + fn append_null(&mut self); +} + +impl StringArrayBuilder for GenericStringBuilder { + fn append_value(&mut self, val: &str) { + GenericStringBuilder::append_value(self, val); + } + fn append_null(&mut self) { + GenericStringBuilder::append_null(self); + } +} + +impl StringArrayBuilder for StringViewBuilder { + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val); + } + fn append_null(&mut self) { + StringViewBuilder::append_null(self); + } +} + +fn spark_substring_impl<'a, V, B>( + string_array: &V, + start_array: &Int64Array, + length_array: Option<&Int64Array>, + mut builder: B, +) -> Result +where + V: StringArrayType<'a>, + B: StringArrayBuilder, +{ + let is_ascii 
= enable_ascii_fast_path(string_array, start_array, length_array); + + for i in 0..string_array.len() { + if string_array.is_null(i) || start_array.is_null(i) { + builder.append_null(); + continue; + } + + if let Some(len_arr) = length_array + && len_arr.is_null(i) + { + builder.append_null(); + continue; + } + + let string = string_array.value(i); + let start = start_array.value(i); + let len_opt = length_array.map(|arr| arr.value(i)); + + // Spark: negative length returns empty string + if let Some(len) = len_opt + && len < 0 + { + builder.append_value(""); + continue; + } + + let string_len = if is_ascii { + string.len() + } else { + string.chars().count() + }; + + let adjusted_start = spark_start_to_datafusion_start(start, string_len); + + let (byte_start, byte_end) = get_true_start_end( + string, + adjusted_start, + len_opt.map(|l| l as u64), + is_ascii, + ); + let substr = &string[byte_start..byte_end]; + builder.append_value(substr); + } + + Ok(builder.finish()) +} diff --git a/datafusion/spark/src/function/url/parse_url.rs b/datafusion/spark/src/function/url/parse_url.rs index e82ef28045a33..7beb02f7750ff 100644 --- a/datafusion/spark/src/function/url/parse_url.rs +++ b/datafusion/spark/src/function/url/parse_url.rs @@ -217,7 +217,12 @@ pub fn spark_handled_parse_url( handler_err, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {}, {})", + url.data_type(), + part.data_type(), + key.data_type() + ), } } else { // The 'key' argument is omitted, assume all values are null @@ -253,7 +258,11 @@ pub fn spark_handled_parse_url( handler_err, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {})", + url.data_type(), + part.data_type() + ), } } } diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index aad3ceed68ce3..9575f560b8d0e 100644 --- 
a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -22,7 +22,6 @@ #![cfg_attr(docsrs, feature(doc_cfg))] // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Spark Expression packages for [DataFusion]. @@ -92,9 +91,49 @@ //! let expr = sha2(col("my_data"), lit(256)); //! ``` //! +//! # Example: using the Spark expression planner +//! +//! The [`planner::SparkFunctionPlanner`] provides Spark-compatible expression +//! planning, such as mapping SQL `EXTRACT` expressions to Spark's `date_part` +//! function. To use it, register it with your session context: +//! +//! ```ignore +//! use std::sync::Arc; +//! use datafusion::prelude::SessionContext; +//! use datafusion_spark::planner::SparkFunctionPlanner; +//! +//! let mut ctx = SessionContext::new(); +//! // Register the Spark expression planner +//! ctx.register_expr_planner(Arc::new(SparkFunctionPlanner))?; +//! // Now EXTRACT expressions will use Spark semantics +//! let df = ctx.sql("SELECT EXTRACT(YEAR FROM timestamp_col) FROM my_table").await?; +//! ``` +//! //![`Expr`]: datafusion_expr::Expr +//! +//! # Example: enabling Apache Spark features with SessionStateBuilder +//! +//! The recommended way to enable Apache Spark compatibility is to use the +//! `SessionStateBuilderSpark` extension trait. This registers all +//! Apache Spark functions (scalar, aggregate, window, and table) as well as the Apache Spark +//! expression planner. +//! +//! Enable the `core` feature in your `Cargo.toml`: +//! ```toml +//! datafusion-spark = { version = "X", features = ["core"] } +//! ``` +//! +//! Then use the extension trait - see [`SessionStateBuilderSpark::with_spark_features`] +//! for an example. 
pub mod function; +pub mod planner; + +#[cfg(feature = "core")] +mod session_state; + +#[cfg(feature = "core")] +pub use session_state::SessionStateBuilderSpark; use datafusion_catalog::TableFunction; use datafusion_common::Result; diff --git a/datafusion/spark/src/planner.rs b/datafusion/spark/src/planner.rs new file mode 100644 index 0000000000000..2dafbb1f9a570 --- /dev/null +++ b/datafusion/spark/src/planner.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr::Expr; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; + +#[derive(Default, Debug)] +pub struct SparkFunctionPlanner; + +impl ExprPlanner for SparkFunctionPlanner { + fn plan_extract( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::datetime::date_part(), args), + ))) + } + + fn plan_substring( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::string::substring(), args), + ))) + } +} diff --git a/datafusion/spark/src/session_state.rs b/datafusion/spark/src/session_state.rs new file mode 100644 index 0000000000000..e39de3a5888ea --- /dev/null +++ b/datafusion/spark/src/session_state.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::execution::SessionStateBuilder; + +use crate::planner::SparkFunctionPlanner; +use crate::{ + all_default_aggregate_functions, all_default_scalar_functions, + all_default_table_functions, all_default_window_functions, +}; + +/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`]. +/// +/// This trait provides a convenient way to register all Apache Spark-compatible +/// functions and planners with a DataFusion session. +/// +/// # Example +/// +/// ```rust +/// use datafusion::execution::SessionStateBuilder; +/// use datafusion_spark::SessionStateBuilderSpark; +/// +/// // Create a SessionState with Apache Spark features enabled +/// // note: the order matters here, `with_spark_features` should be +/// // called after `with_default_features` to overwrite any existing functions +/// let state = SessionStateBuilder::new() +/// .with_default_features() +/// .with_spark_features() +/// .build(); +/// ``` +pub trait SessionStateBuilderSpark { + /// Adds all expr_planners, scalar, aggregate, window and table functions + /// compatible with Apache Spark. + /// + /// Note: This overwrites any previously registered items with the same name. + fn with_spark_features(self) -> Self; +} + +impl SessionStateBuilderSpark for SessionStateBuilder { + fn with_spark_features(mut self) -> Self { + self.expr_planners() + .get_or_insert_with(Vec::new) + // planners are evaluated in order of insertion. 
Push Apache Spark function planner to the front + // to take precedence over others + .insert(0, Arc::new(SparkFunctionPlanner)); + + self.scalar_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_scalar_functions()); + + self.aggregate_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_aggregate_functions()); + + self.window_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_window_functions()); + + self.table_functions() + .get_or_insert_with(HashMap::new) + .extend( + all_default_table_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)), + ); + + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_state_with_spark_features() { + let state = SessionStateBuilder::new().with_spark_features().build(); + + assert!( + state.scalar_functions().contains_key("sha2"), + "Apache Spark scalar function 'sha2' should be registered" + ); + + assert!( + state.aggregate_functions().contains_key("try_sum"), + "Apache Spark aggregate function 'try_sum' should be registered" + ); + + assert!( + !state.expr_planners().is_empty(), + "Apache Spark expr planners should be registered" + ); + } +} diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index a814292a3d71d..b7338cb764d77 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ bigdecimal = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["sql"] } datafusion-expr = { workspace = true, features = ["sql"] } +datafusion-functions-nested = { workspace = true, features = ["sql"] } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 34fbe2edf8dd9..cca09df0db027 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -76,15 +76,16 @@ impl SqlToRel<'_, 
S> { } // Check the outer query schema - if let Some(outer) = planner_context.outer_query_schema() - && let Ok((qualifier, field)) = + for outer in planner_context.outer_schemas_iter() { + if let Ok((qualifier, field)) = outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) - { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - return Ok(Expr::OuterReferenceColumn( - Arc::clone(field), - Column::from((qualifier, field)), - )); + { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + return Ok(Expr::OuterReferenceColumn( + Arc::clone(field), + Column::from((qualifier, field)), + )); + } } // Default case @@ -172,14 +173,14 @@ impl SqlToRel<'_, S> { not_impl_err!("compound identifier: {ids:?}") } else { // Check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { + for outer in planner_context.outer_schemas_iter() { let search_result = search_dfschema(&ids, outer); - match search_result { + let result = match search_result { // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support nested identifiers for OuterReferenceColumn + // TODO: remove this when we have support for nested identifiers for OuterReferenceColumn not_impl_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", Column::from((qualifier, field)) @@ -195,26 +196,20 @@ impl SqlToRel<'_, S> { )) } // Found no matching field, will return a default - None => { - let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = - form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) - } - } - } else { - let s = &ids[0..ids.len()]; - // Safe unwrap as s can never be empty or exceed the 
bounds - let (relation, column_name) = form_identifier(s).unwrap(); - let mut column = Column::new(relation, column_name); - if self.options.collect_spans - && let Some(span) = ids_span - { - column.spans_mut().add_span(span); - } - Ok(Expr::Column(column)) + None => continue, + }; + return result; + } + // Safe unwrap as column name can never be empty or exceed the bounds + let (relation, column_name) = + form_identifier(&ids[0..ids.len()]).unwrap(); + let mut column = Column::new(relation, column_name); + if self.options.collect_spans + && let Some(span) = ids_span + { + column.spans_mut().add_span(span); } + Ok(Expr::Column(column)) } } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index fcd7d6376d21c..dbf2ce67732ec 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -32,6 +32,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::SetQuantifier; use datafusion_expr::expr::{InList, WildcardOptions}; use datafusion_expr::{ Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, @@ -39,6 +40,7 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_functions_nested::expr_fn::array_has; mod binary_op; mod function; @@ -594,32 +596,44 @@ impl SqlToRel<'_, S> { // ANY/SOME are equivalent, this field specifies which the user // specified but it doesn't affect the plan so ignore the field is_some: _, - } => { - let mut binary_expr = RawBinaryExpr { - op: compare_op, - left: self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?, - right: self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?, - }; - for planner in self.context_provider.get_expr_planners() { - match planner.plan_any(binary_expr)? 
{ - PlannerResult::Planned(expr) => { - return Ok(expr); - } - PlannerResult::Original(expr) => { - binary_expr = expr; - } + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::Any, + schema, + planner_context, + ), + _ => { + if compare_op != BinaryOperator::Eq { + plan_err!( + "Unsupported AnyOp: '{compare_op}', only '=' is supported" + ) + } else { + let left_expr = + self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = + self.sql_to_expr(*right, schema, planner_context)?; + Ok(array_has(right_expr, left_expr)) } } - not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") - } + }, + SQLExpr::AllOp { + left, + compare_op, + right, + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::All, + schema, + planner_context, + ), + _ => not_impl_err!("ALL only supports subquery comparison currently"), + }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { qualifier: None, diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index ec34ff3d53426..662c44f6f2620 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -17,10 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Diagnostic, Result, Span, Spans, plan_err}; -use datafusion_expr::expr::{Exists, InSubquery}; +use datafusion_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; use datafusion_expr::{Expr, LogicalPlan, Subquery}; use sqlparser::ast::Expr as SQLExpr; -use sqlparser::ast::{Query, SelectItem, SetExpr}; +use sqlparser::ast::{BinaryOperator, Query, SelectItem, SetExpr}; use std::sync::Arc; impl SqlToRel<'_, S> { @@ -31,11 +31,10 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let 
old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); Ok(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), @@ -54,8 +53,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = &subquery.body.as_ref() { @@ -70,7 +68,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -98,8 +96,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { for item in &select.projection { @@ -112,7 +109,7 @@ impl SqlToRel<'_, S> { } let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -162,4 +159,50 @@ impl SqlToRel<'_, S> { diagnostic.add_help(help_message, None); 
diagnostic } + + pub(super) fn parse_set_comparison_subquery( + &self, + left_expr: SQLExpr, + subquery: Query, + compare_op: &BinaryOperator, + quantifier: SetQuantifier, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + planner_context.append_outer_query_schema(Arc::new(input_schema.clone())); + + let mut spans = Spans::new(); + if let SetExpr::Select(select) = subquery.body.as_ref() { + for item in &select.projection { + if let SelectItem::ExprWithAlias { alias, .. } = item + && let Some(span) = Span::try_from_sqlparser_span(alias.span) + { + spans.add_span(span); + } + } + } + + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.pop_outer_query_schema(); + + self.validate_single_column( + &sub_plan, + &spans, + "Too many columns! The subquery should only return one column", + "Select only one column in the subquery", + )?; + + let expr_obj = self.sql_to_expr(left_expr, input_schema, planner_context)?; + Ok(Expr::SetComparison(SetComparison::new( + Box::new(expr_obj), + Subquery { + subquery: Arc::new(sub_plan), + outer_ref_columns, + spans, + }, + self.parse_sql_binary_op(compare_op)?, + quantifier, + ))) + } } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index b21eb52920ab5..7fef670933f9a 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! 
This crate provides: diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 27db2b0f97579..1ecf90b7947c3 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -363,28 +363,49 @@ const DEFAULT_DIALECT: GenericDialect = GenericDialect {}; /// # Ok(()) /// # } /// ``` -pub struct DFParserBuilder<'a> { - /// The SQL string to parse - sql: &'a str, +pub struct DFParserBuilder<'a, 'b> { + /// Parser input: either raw SQL or tokens + input: ParserInput<'a>, /// The Dialect to use (defaults to [`GenericDialect`] - dialect: &'a dyn Dialect, + dialect: &'b dyn Dialect, /// The recursion limit while parsing recursion_limit: usize, } -impl<'a> DFParserBuilder<'a> { +/// Describes a possible input for parser +pub enum ParserInput<'a> { + /// Raw SQL. Tokenization will be performed automatically as a + /// part of [`DFParserBuilder::build`] + Sql(&'a str), + /// Tokens + Tokens(Vec), +} + +impl<'a> From<&'a str> for ParserInput<'a> { + fn from(sql: &'a str) -> Self { + Self::Sql(sql) + } +} + +impl From> for ParserInput<'static> { + fn from(tokens: Vec) -> Self { + Self::Tokens(tokens) + } +} + +impl<'a, 'b> DFParserBuilder<'a, 'b> { /// Create a new parser builder for the specified tokens using the /// [`GenericDialect`]. - pub fn new(sql: &'a str) -> Self { + pub fn new(input: impl Into>) -> Self { Self { - sql, + input: input.into(), dialect: &DEFAULT_DIALECT, recursion_limit: DEFAULT_RECURSION_LIMIT, } } /// Adjust the parser builder's dialect. 
Defaults to [`GenericDialect`] - pub fn with_dialect(mut self, dialect: &'a dyn Dialect) -> Self { + pub fn with_dialect(mut self, dialect: &'b dyn Dialect) -> Self { self.dialect = dialect; self } @@ -395,12 +416,18 @@ impl<'a> DFParserBuilder<'a> { self } - pub fn build(self) -> Result, DataFusionError> { - let mut tokenizer = Tokenizer::new(self.dialect, self.sql); - // Convert TokenizerError -> ParserError - let tokens = tokenizer - .tokenize_with_location() - .map_err(ParserError::from)?; + /// Build resulting parser + pub fn build(self) -> Result, DataFusionError> { + let tokens = match self.input { + ParserInput::Tokens(tokens) => tokens, + ParserInput::Sql(sql) => { + let mut tokenizer = Tokenizer::new(self.dialect, sql); + // Convert TokenizerError -> ParserError + tokenizer + .tokenize_with_location() + .map_err(ParserError::from)? + } + }; Ok(DFParser { parser: Parser::new(self.dialect) @@ -658,7 +685,7 @@ impl<'a> DFParser<'a> { } } } else { - let token = self.parser.next_token(); + let token = self.parser.peek_token(); if token == Token::EOF || token == Token::SemiColon { break; } else { @@ -1079,7 +1106,7 @@ impl<'a> DFParser<'a> { } } } else { - let token = self.parser.next_token(); + let token = self.parser.peek_token(); if token == Token::EOF || token == Token::SemiColon { break; } else { @@ -1162,7 +1189,7 @@ mod tests { BinaryOperator, DataType, ExactNumberInfo, Expr, Ident, ValueWithSpan, }; use sqlparser::dialect::SnowflakeDialect; - use sqlparser::tokenizer::Span; + use sqlparser::tokenizer::{Location, Span, Whitespace}; fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), DataFusionError> { let statements = DFParser::parse_sql(sql)?; @@ -2026,6 +2053,78 @@ mod tests { ); } + #[test] + fn test_multistatement() { + let sql = "COPY foo TO bar STORED AS CSV; \ + CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'; \ + RESET var;"; + let statements = DFParser::parse_sql(sql).unwrap(); + assert_eq!( + statements, + vec![ + 
Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + partitioned_by: vec![], + stored_as: Some("CSV".to_owned()), + options: vec![], + }), + { + let name = ObjectName::from(vec![Ident::from("t")]); + let display = None; + Statement::CreateExternalTable(CreateExternalTable { + name: name.clone(), + columns: vec![make_column_def("c1", DataType::Int(display))], + file_type: "CSV".to_string(), + location: "foo.csv".into(), + table_partition_cols: vec![], + order_exprs: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + unbounded: false, + options: vec![], + constraints: vec![], + }) + }, + { + let name = ObjectName::from(vec![Ident::from("var")]); + Statement::Reset(ResetStatement::Variable(name)) + } + ] + ); + } + + #[test] + fn test_custom_tokens() { + // Span mock. + let span = Span { + start: Location { line: 0, column: 0 }, + end: Location { line: 0, column: 0 }, + }; + let tokens = vec![ + TokenWithSpan { + token: Token::make_keyword("SELECT"), + span, + }, + TokenWithSpan { + token: Token::Whitespace(Whitespace::Space), + span, + }, + TokenWithSpan { + token: Token::Placeholder("1".to_string()), + span, + }, + ]; + + let statements = DFParserBuilder::new(tokens) + .build() + .unwrap() + .parse_statements() + .unwrap(); + assert_eq!(statements.len(), 1); + } + fn expect_parse_expr_ok(sql: &str, expected: ExprWithAlias) { let expr = DFParser::parse_sql_into_expr(sql).unwrap(); assert_eq!(expr, expected, "actual:\n{expr:#?}"); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index eb798b71e4558..dd63cfce5e4a2 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -261,8 +261,10 @@ pub struct PlannerContext { /// Map of CTE name to logical plan of the WITH clause. 
/// Use `Arc` to allow cheap cloning ctes: HashMap>, - /// The query schema of the outer query plan, used to resolve the columns in subquery - outer_query_schema: Option, + + /// The queries schemas of outer query relations, used to resolve the outer referenced + /// columns in subquery (recursive aware) + outer_queries_schemas_stack: Vec, /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. outer_from_schema: Option, @@ -282,7 +284,7 @@ impl PlannerContext { Self { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), - outer_query_schema: None, + outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, } @@ -297,19 +299,42 @@ impl PlannerContext { self } - // Return a reference to the outer query's schema - pub fn outer_query_schema(&self) -> Option<&DFSchema> { - self.outer_query_schema.as_ref().map(|s| s.as_ref()) + /// Return the stack of outer relations' schemas, the outer most + /// relation are at the first entry + pub fn outer_queries_schemas(&self) -> &[DFSchemaRef] { + &self.outer_queries_schemas_stack + } + + /// Return an iterator of the subquery relations' schemas, innermost + /// relation is returned first. + /// + /// This order corresponds to the order of resolution when looking up column + /// references in subqueries, which start from the innermost relation and + /// then look up the outer relations one by one until a match is found or no + /// more outer relation exist. + /// + /// NOTE this is *REVERSED* order of [`Self::outer_queries_schemas`] + /// + /// This is useful to resolve the column reference in the subquery by + /// looking up the outer query schemas one by one. 
+ pub fn outer_schemas_iter(&self) -> impl Iterator { + self.outer_queries_schemas_stack.iter().rev() } /// Sets the outer query schema, returning the existing one, if /// any - pub fn set_outer_query_schema( - &mut self, - mut schema: Option, - ) -> Option { - std::mem::swap(&mut self.outer_query_schema, &mut schema); - schema + pub fn append_outer_query_schema(&mut self, schema: DFSchemaRef) { + self.outer_queries_schemas_stack.push(schema); + } + + /// The schema of the adjacent outer relation + pub fn latest_outer_query_schema(&self) -> Option<&DFSchemaRef> { + self.outer_queries_schemas_stack.last() + } + + /// Remove the schema of the adjacent outer relation + pub fn pop_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.pop() } pub fn set_table_schema( @@ -823,7 +848,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::HugeInt | SQLDataType::UHugeInt | SQLDataType::UBigInt - | SQLDataType::TimestampNtz + | SQLDataType::TimestampNtz{..} | SQLDataType::NamedTable { .. } | SQLDataType::TsVector | SQLDataType::TsQuery diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index eba48a2401c38..1b7bb856a592b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -170,6 +170,7 @@ impl SqlToRel<'_, S> { name: alias, // Apply to all fields columns: vec![], + explicit: true, }, ), PipeOperator::Union { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 3115d8dfffbd2..6558763ca4e42 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -93,7 +93,7 @@ impl SqlToRel<'_, S> { match self.create_extension_relation(relation, planner_context)? { RelationPlanning::Planned(planned) => planned, RelationPlanning::Original(original) => { - self.create_default_relation(original, planner_context)? + Box::new(self.create_default_relation(*original, planner_context)?) 
} }; @@ -112,7 +112,7 @@ impl SqlToRel<'_, S> { ) -> Result { let planners = self.context_provider.get_relation_planners(); if planners.is_empty() { - return Ok(RelationPlanning::Original(relation)); + return Ok(RelationPlanning::Original(Box::new(relation))); } let mut current_relation = relation; @@ -127,12 +127,12 @@ impl SqlToRel<'_, S> { return Ok(RelationPlanning::Planned(planned)); } RelationPlanning::Original(original) => { - current_relation = original; + current_relation = *original; } } } - Ok(RelationPlanning::Original(current_relation)) + Ok(RelationPlanning::Original(Box::new(current_relation))) } fn create_default_relation( @@ -262,9 +262,10 @@ impl SqlToRel<'_, S> { } => { let tbl_func_ref = self.object_name_to_table_reference(name)?; let schema = planner_context - .outer_query_schema() + .outer_queries_schemas() + .last() .cloned() - .unwrap_or_else(DFSchema::empty); + .unwrap_or_else(|| Arc::new(DFSchema::empty())); let func_args = args .into_iter() .map(|arg| match arg { @@ -310,20 +311,24 @@ impl SqlToRel<'_, S> { let old_from_schema = planner_context .set_outer_from_schema(None) .unwrap_or_else(|| Arc::new(DFSchema::empty())); - let new_query_schema = match planner_context.outer_query_schema() { - Some(old_query_schema) => { + let outer_query_schema = planner_context.pop_outer_query_schema(); + let new_query_schema = match outer_query_schema { + Some(ref old_query_schema) => { let mut new_query_schema = old_from_schema.as_ref().clone(); - new_query_schema.merge(old_query_schema); - Some(Arc::new(new_query_schema)) + new_query_schema.merge(old_query_schema.as_ref()); + Arc::new(new_query_schema) } - None => Some(Arc::clone(&old_from_schema)), + None => Arc::clone(&old_from_schema), }; - let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + planner_context.append_outer_query_schema(new_query_schema); let plan = self.create_relation(subquery, planner_context)?; let outer_ref_columns = plan.all_out_ref_exprs(); - 
planner_context.set_outer_query_schema(old_query_schema); + planner_context.pop_outer_query_schema(); + if let Some(schema) = outer_query_schema { + planner_context.append_outer_query_schema(schema); + } planner_context.set_outer_from_schema(Some(old_from_schema)); // We can omit the subquery wrapper if there are no columns diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 148e886161fcb..955dbb86602a3 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::TableReference; use std::collections::BTreeSet; use std::ops::ControlFlow; +use datafusion_common::{DataFusionError, Result}; + +use crate::TableReference; use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement}; use crate::planner::object_name_to_table_reference; use sqlparser::ast::*; @@ -45,27 +47,40 @@ const INFORMATION_SCHEMA_TABLES: &[&str] = &[ PARAMETERS, ]; +// Collect table/CTE references as `TableReference`s and normalize them during traversal. +// This avoids a second normalization/conversion pass after visiting the AST. struct RelationVisitor { - relations: BTreeSet, - all_ctes: BTreeSet, - ctes_in_scope: Vec, + relations: BTreeSet, + all_ctes: BTreeSet, + ctes_in_scope: Vec, + enable_ident_normalization: bool, } impl RelationVisitor { /// Record the reference to `relation`, if it's not a CTE reference. 
- fn insert_relation(&mut self, relation: &ObjectName) { - if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) { - self.relations.insert(relation.clone()); + fn insert_relation(&mut self, relation: &ObjectName) -> ControlFlow { + match object_name_to_table_reference( + relation.clone(), + self.enable_ident_normalization, + ) { + Ok(relation) => { + if !self.relations.contains(&relation) + && !self.ctes_in_scope.contains(&relation) + { + self.relations.insert(relation); + } + ControlFlow::Continue(()) + } + Err(e) => ControlFlow::Break(e), } } } impl Visitor for RelationVisitor { - type Break = (); + type Break = DataFusionError; - fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> { - self.insert_relation(relation); - ControlFlow::Continue(()) + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { + self.insert_relation(relation) } fn pre_visit_query(&mut self, q: &Query) -> ControlFlow { @@ -78,10 +93,16 @@ impl Visitor for RelationVisitor { if !with.recursive { // This is a bit hackish as the CTE will be visited again as part of visiting `q`, // but thankfully `insert_relation` is idempotent. 
- let _ = cte.visit(self); + cte.visit(self)?; + } + let cte_name = ObjectName::from(vec![cte.alias.name.clone()]); + match object_name_to_table_reference( + cte_name, + self.enable_ident_normalization, + ) { + Ok(cte_ref) => self.ctes_in_scope.push(cte_ref), + Err(e) => return ControlFlow::Break(e), } - self.ctes_in_scope - .push(ObjectName::from(vec![cte.alias.name.clone()])); } } ControlFlow::Continue(()) @@ -97,13 +118,13 @@ impl Visitor for RelationVisitor { ControlFlow::Continue(()) } - fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> { + fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow { if let Statement::ShowCreate { obj_type: ShowCreateObject::Table | ShowCreateObject::View, obj_name, } = statement { - self.insert_relation(obj_name) + self.insert_relation(obj_name)?; } // SHOW statements will later be rewritten into a SELECT from the information_schema @@ -120,35 +141,53 @@ impl Visitor for RelationVisitor { ); if requires_information_schema { for s in INFORMATION_SCHEMA_TABLES { - self.relations.insert(ObjectName::from(vec![ + // Information schema references are synthesized here, so convert directly. 
+ let obj = ObjectName::from(vec![ Ident::new(INFORMATION_SCHEMA), Ident::new(*s), - ])); + ]); + match object_name_to_table_reference(obj, self.enable_ident_normalization) + { + Ok(tbl_ref) => { + self.relations.insert(tbl_ref); + } + Err(e) => return ControlFlow::Break(e), + } } } ControlFlow::Continue(()) } } -fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { +fn control_flow_to_result(flow: ControlFlow) -> Result<()> { + match flow { + ControlFlow::Continue(()) => Ok(()), + ControlFlow::Break(err) => Err(err), + } +} + +fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) -> Result<()> { match statement { DFStatement::Statement(s) => { - let _ = s.as_ref().visit(visitor); + control_flow_to_result(s.as_ref().visit(visitor))?; } DFStatement::CreateExternalTable(table) => { - visitor.relations.insert(table.name.clone()); + control_flow_to_result(visitor.insert_relation(&table.name))?; } DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { CopyToSource::Relation(table_name) => { - visitor.insert_relation(table_name); + control_flow_to_result(visitor.insert_relation(table_name))?; } CopyToSource::Query(query) => { - let _ = query.visit(visitor); + control_flow_to_result(query.visit(visitor))?; } }, - DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), + DFStatement::Explain(explain) => { + visit_statement(&explain.statement, visitor)?; + } DFStatement::Reset(_) => {} } + Ok(()) } /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. 
@@ -188,26 +227,20 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { pub fn resolve_table_references( statement: &crate::parser::Statement, enable_ident_normalization: bool, -) -> datafusion_common::Result<(Vec, Vec)> { +) -> Result<(Vec, Vec)> { let mut visitor = RelationVisitor { relations: BTreeSet::new(), all_ctes: BTreeSet::new(), ctes_in_scope: vec![], + enable_ident_normalization, }; - visit_statement(statement, &mut visitor); - - let table_refs = visitor - .relations - .into_iter() - .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) - .collect::>()?; - let ctes = visitor - .all_ctes - .into_iter() - .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) - .collect::>()?; - Ok((table_refs, ctes)) + visit_statement(statement, &mut visitor)?; + + Ok(( + visitor.relations.into_iter().collect(), + visitor.all_ctes.into_iter().collect(), + )) } #[cfg(test)] @@ -270,4 +303,57 @@ mod tests { assert_eq!(ctes.len(), 1); assert_eq!(ctes[0].to_string(), "nodes"); } + + #[test] + fn resolve_table_references_cte_with_quoted_reference() { + use crate::parser::DFParser; + + let query = r#"with barbaz as (select 1) select * from "barbaz""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "barbaz"); + // Quoted reference should still resolve to the CTE when normalization is on + assert_eq!(table_refs.len(), 0); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_normalization_off() { + use crate::parser::DFParser; + + let query = r#"with barbaz as (select 1) select * from "barbaz""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap(); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "barbaz"); + // Even 
with normalization off, quoted reference matches same-case CTE name + assert_eq!(table_refs.len(), 0); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_on() { + use crate::parser::DFParser; + + let query = r#"with FOObar as (select 1) select * from "FOObar""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + // CTE name is normalized to lowercase, quoted reference preserves case, so they differ + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "foobar"); + assert_eq!(table_refs.len(), 1); + assert_eq!(table_refs[0].to_string(), "FOObar"); + } + + #[test] + fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_off() { + use crate::parser::DFParser; + + let query = r#"with FOObar as (select 1) select * from "FOObar""#; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap(); + // Without normalization, cases match exactly, so quoted reference resolves to the CTE + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "FOObar"); + assert_eq!(table_refs.len(), 0); + } } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 1d6ccde6be13a..28e7ac2f205b8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, Result, not_impl_err, plan_err}; +use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -637,11 +637,6 @@ impl SqlToRel<'_, S> { match 
selection { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = planner_context.outer_query_schema().cloned(); - let outer_query_schema_vec = outer_query_schema - .as_ref() - .map(|schema| vec![schema]) - .unwrap_or_else(Vec::new); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; @@ -657,9 +652,19 @@ impl SqlToRel<'_, S> { let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; + let mut schema_stack: Vec> = + vec![vec![plan.schema()], fallback_schemas]; + for sc in planner_context.outer_schemas_iter() { + schema_stack.push(vec![sc.as_ref()]); + } + let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], + schema_stack + .iter() + .map(|sc| sc.as_slice()) + .collect::>() + .as_slice(), &[using_columns], )?; diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1acbcc92dfe19..14ec64f874c31 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -55,9 +55,10 @@ use datafusion_expr::{ TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, cast, col, }; use sqlparser::ast::{ - self, BeginTransactionKind, IndexColumn, IndexType, NullsDistinctOption, OrderByExpr, - OrderByOptions, Set, ShowStatementIn, ShowStatementOptions, SqliteOnConflict, - TableObject, UpdateTableFromKind, ValueWithSpan, + self, BeginTransactionKind, CheckConstraint, ForeignKeyConstraint, IndexColumn, + IndexType, NullsDistinctOption, OrderByExpr, OrderByOptions, PrimaryKeyConstraint, + Set, ShowStatementIn, ShowStatementOptions, SqliteOnConflict, TableObject, + UniqueConstraint, Update, UpdateTableFromKind, ValueWithSpan, }; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, @@ -102,38 +103,24 @@ fn get_schema_name(schema_name: &SchemaName) -> String { /// Construct 
`TableConstraint`(s) for the given columns by iterating over /// `columns` and extracting individual inline constraint definitions. fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - let mut constraints = vec![]; + let mut constraints: Vec = vec![]; for column in columns { for ast::ColumnOptionDef { name, option } in &column.options { match option { - ast::ColumnOption::Unique { - is_primary: false, + ast::ColumnOption::Unique(UniqueConstraint { characteristics, - } => constraints.push(TableConstraint::Unique { + name, + index_name: _index_name, + index_type_display: _index_type_display, + index_type: _index_type, + columns: _column, + index_options: _index_options, + nulls_distinct: _nulls_distinct, + }) => constraints.push(TableConstraint::Unique(UniqueConstraint { name: name.clone(), - columns: vec![IndexColumn { - column: OrderByExpr { - expr: SQLExpr::Identifier(column.name.clone()), - options: OrderByOptions { - asc: None, - nulls_first: None, - }, - with_fill: None, - }, - operator_class: None, - }], - characteristics: *characteristics, index_name: None, index_type_display: ast::KeyOrIndexDisplay::None, index_type: None, - index_options: vec![], - nulls_distinct: NullsDistinctOption::None, - }), - ast::ColumnOption::Unique { - is_primary: true, - characteristics, - } => constraints.push(TableConstraint::PrimaryKey { - name: name.clone(), columns: vec![IndexColumn { column: OrderByExpr { expr: SQLExpr::Identifier(column.name.clone()), @@ -145,35 +132,69 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { + constraints.push(TableConstraint::PrimaryKey(PrimaryKeyConstraint { + name: name.clone(), + index_name: None, + index_type: None, + columns: vec![IndexColumn { + column: OrderByExpr { + expr: SQLExpr::Identifier(column.name.clone()), + options: OrderByOptions { + asc: None, + nulls_first: None, + }, + with_fill: None, + }, + operator_class: None, + }], + index_options: vec![], + characteristics: 
*characteristics, + })) + } + ast::ColumnOption::ForeignKey(ForeignKeyConstraint { foreign_table, referred_columns, on_delete, on_update, characteristics, - } => constraints.push(TableConstraint::ForeignKey { - name: name.clone(), - columns: vec![], - foreign_table: foreign_table.clone(), - referred_columns: referred_columns.to_vec(), - on_delete: *on_delete, - on_update: *on_update, - characteristics: *characteristics, - index_name: None, - }), - ast::ColumnOption::Check(expr) => { - constraints.push(TableConstraint::Check { + name: _name, + index_name: _index_name, + columns: _columns, + match_kind: _match_kind, + }) => { + constraints.push(TableConstraint::ForeignKey(ForeignKeyConstraint { name: name.clone(), - expr: Box::new(expr.clone()), - enforced: None, - }) - } - // Other options are not constraint related. + index_name: None, + columns: vec![], + foreign_table: foreign_table.clone(), + referred_columns: referred_columns.clone(), + on_delete: *on_delete, + on_update: *on_update, + match_kind: None, + characteristics: *characteristics, + })) + } + ast::ColumnOption::Check(CheckConstraint { + name, + expr, + enforced: _enforced, + }) => constraints.push(TableConstraint::Check(CheckConstraint { + name: name.clone(), + expr: expr.clone(), + enforced: None, + })), ast::ColumnOption::Default(_) | ast::ColumnOption::Null | ast::ColumnOption::NotNull @@ -191,7 +212,8 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec {} + | ast::ColumnOption::Collation(_) + | ast::ColumnOption::Invisible => {} } } } @@ -341,15 +363,17 @@ impl SqlToRel<'_, S> { "Hive distribution not supported: {hive_distribution:?}" )?; } - if !matches!( - hive_formats, - Some(ast::HiveFormat { - row_format: None, - serde_properties: None, - storage: None, - location: None, - }) - ) { + if hive_formats.is_some() + && !matches!( + hive_formats, + Some(ast::HiveFormat { + row_format: None, + serde_properties: None, + storage: None, + location: None, + }) + ) + { return 
not_impl_err!( "Hive formats not supported: {hive_formats:?}" )?; @@ -557,7 +581,7 @@ impl SqlToRel<'_, S> { } } } - Statement::CreateView { + Statement::CreateView(ast::CreateView { or_replace, materialized, name, @@ -574,7 +598,7 @@ impl SqlToRel<'_, S> { or_alter, secure, name_before_not_exists, - } => { + }) => { if materialized { return not_impl_err!("Materialized views not supported")?; } @@ -596,7 +620,7 @@ impl SqlToRel<'_, S> { // put the statement back together temporarily to get the SQL // string representation - let stmt = Statement::CreateView { + let stmt = Statement::CreateView(ast::CreateView { or_replace, materialized, name, @@ -613,16 +637,16 @@ impl SqlToRel<'_, S> { or_alter, secure, name_before_not_exists, - }; + }); let sql = stmt.to_string(); - let Statement::CreateView { + let Statement::CreateView(ast::CreateView { name, columns, query, or_replace, temporary, .. - } = stmt + }) = stmt else { return internal_err!("Unreachable code in create view"); }; @@ -965,6 +989,7 @@ impl SqlToRel<'_, S> { has_table_keyword, settings, format_clause, + insert_token: _insert_token, // record the location the `INSERT` token }) => { let table_name = match table { TableObject::TableName(table_name) => table_name, @@ -1025,7 +1050,7 @@ impl SqlToRel<'_, S> { let _ = has_table_keyword; self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } - Statement::Update { + Statement::Update(Update { table, assignments, from, @@ -1033,7 +1058,8 @@ impl SqlToRel<'_, S> { returning, or, limit, - } => { + update_token: _, + }) => { let from_clauses = from.map(|update_table_from_kind| match update_table_from_kind { UpdateTableFromKind::BeforeSet(from_clauses) => from_clauses, @@ -1064,6 +1090,7 @@ impl SqlToRel<'_, S> { from, order_by, limit, + delete_token: _, }) => { if !tables.is_empty() { plan_err!("DELETE not supported")?; @@ -1081,12 +1108,8 @@ impl SqlToRel<'_, S> { plan_err!("Delete-order-by clause not yet supported")?; } - if limit.is_some() { - 
plan_err!("Delete-limit clause not yet supported")?; - } - let table_name = self.get_delete_target(from)?; - self.delete_to_plan(&table_name, selection) + self.delete_to_plan(&table_name, selection, limit) } Statement::StartTransaction { @@ -1295,7 +1318,8 @@ impl SqlToRel<'_, S> { let function_body = match function_body { Some(r) => Some(self.sql_to_expr( match r { - ast::CreateFunctionBody::AsBeforeOptions(expr) => expr, + // `link_symbol` indicates if the primary expression contains the name of shared library file. + ast::CreateFunctionBody::AsBeforeOptions{body: expr, link_symbol: _link_symbol} => expr, ast::CreateFunctionBody::AsAfterOptions(expr) => expr, ast::CreateFunctionBody::Return(expr) => expr, ast::CreateFunctionBody::AsBeginEnd(_) => { @@ -1338,11 +1362,11 @@ impl SqlToRel<'_, S> { Ok(LogicalPlan::Ddl(statement)) } - Statement::DropFunction { + Statement::DropFunction(ast::DropFunction { if_exists, func_desc, - .. - } => { + drop_behavior: _, + }) => { // According to postgresql documentation it can be only one function // specified in drop statement if let Some(desc) = func_desc.first() { @@ -1362,6 +1386,56 @@ impl SqlToRel<'_, S> { exec_err!("Function name not provided") } } + Statement::Truncate(ast::Truncate { + table_names, + partitions, + identity, + cascade, + on_cluster, + table, + }) => { + let _ = table; // Support TRUNCATE TABLE and TRUNCATE syntax + if table_names.len() != 1 { + return not_impl_err!( + "TRUNCATE with multiple tables is not supported" + ); + } + + let target = &table_names[0]; + if target.only { + return not_impl_err!("TRUNCATE with ONLY is not supported"); + } + if partitions.is_some() { + return not_impl_err!("TRUNCATE with PARTITION is not supported"); + } + if identity.is_some() { + return not_impl_err!( + "TRUNCATE with RESTART/CONTINUE IDENTITY is not supported" + ); + } + if cascade.is_some() { + return not_impl_err!( + "TRUNCATE with CASCADE/RESTRICT is not supported" + ); + } + if on_cluster.is_some() { + return 
not_impl_err!("TRUNCATE with ON CLUSTER is not supported"); + } + let table = self.object_name_to_table_reference(target.name.clone())?; + let source = self.context_provider.get_table_source(table.clone())?; + + // TRUNCATE does not operate on input rows. The EmptyRelation is a logical placeholder + // since the real operation is executed directly by the TableProvider's truncate() hook. + Ok(LogicalPlan::Dml(DmlStatement::new( + table.clone(), + source, + WriteOp::Truncate, + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::empty()), + })), + ))) + } Statement::CreateIndex(CreateIndex { name, table_name, @@ -1716,8 +1790,17 @@ impl SqlToRel<'_, S> { let constraints = constraints .iter() .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let constraint_name = match name { + TableConstraint::Unique(UniqueConstraint { + name, + index_name: _, + index_type_display: _, + index_type: _, + columns, + index_options: _, + characteristics: _, + nulls_distinct: _, + }) => { + let constraint_name = match &name { Some(name) => &format!("unique constraint with name '{name}'"), None => "unique constraint", }; @@ -1729,7 +1812,14 @@ impl SqlToRel<'_, S> { )?; Ok(Constraint::Unique(indices)) } - TableConstraint::PrimaryKey { columns, .. 
} => { + TableConstraint::PrimaryKey(PrimaryKeyConstraint { + name: _, + index_name: _, + index_type: _, + columns, + index_options: _, + characteristics: _, + }) => { // Get primary key indices in the schema let indices = self.get_constraint_column_indices( df_schema, @@ -1978,6 +2068,7 @@ impl SqlToRel<'_, S> { &self, table_name: &ObjectName, predicate_expr: Option, + limit: Option, ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; @@ -1991,7 +2082,7 @@ impl SqlToRel<'_, S> { .build()?; let mut planner_context = PlannerContext::new(); - let source = match predicate_expr { + let mut source = match predicate_expr { None => scan, Some(predicate_expr) => { let filter_expr = @@ -2008,6 +2099,14 @@ impl SqlToRel<'_, S> { } }; + if let Some(limit) = limit { + let empty_schema = DFSchema::empty(); + let limit = self.sql_to_expr(limit, &empty_schema, &mut planner_context)?; + source = LogicalPlanBuilder::from(source) + .limit_by_expr(None, Some(limit))? + .build()? 
+ } + let plan = LogicalPlan::Dml(DmlStatement::new( table_ref, table_source, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5746a568e712b..5f6612830ac1f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, - expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, + expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, }; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::tokenizer::Span; @@ -393,6 +393,33 @@ impl Unparser<'_> { negated: insubq.negated, }) } + Expr::SetComparison(set_cmp) => { + let left = Box::new(self.expr_to_sql_inner(set_cmp.expr.as_ref())?); + let sub_statement = + self.plan_to_sql(set_cmp.subquery.subquery.as_ref())?; + let sub_query = if let ast::Statement::Query(inner_query) = sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + let compare_op = self.op_to_sql(&set_cmp.op)?; + match set_cmp.quantifier { + SetQuantifier::Any => Ok(ast::Expr::AnyOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + is_some: false, + }), + SetQuantifier::All => Ok(ast::Expr::AllOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + }), + } + } Expr::Exists(Exists { subquery, negated }) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -1414,6 +1441,7 @@ impl Unparser<'_> { ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) 
=> not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(_k, v) => self.scalar_to_sql(v), + ScalarValue::RunEndEncoded(_, _, v) => self.scalar_to_sql(v), } } @@ -1763,6 +1791,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::RunEndEncoded(_, val) => { + self.arrow_dtype_to_ast_dtype(val.data_type()) + } DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1784,9 +1815,6 @@ impl Unparser<'_> { DataType::Map(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::RunEndEncoded(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type}") - } } } } @@ -2289,6 +2317,17 @@ mod tests { ), "'foo'", ), + ( + Expr::Literal( + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), + "'foo'", + ), ( Expr::Literal( ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< @@ -3158,6 +3197,22 @@ mod tests { Ok(()) } + #[test] + fn test_run_end_encoded_to_sql() -> Result<()> { + let dialect = CustomDialectBuilder::new().build(); + + let unparser = Unparser::new(&dialect); + + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + ))?; + + assert_eq!(ast_dtype, ast::DataType::Varchar(None)); + + Ok(()) + } + #[test] fn test_utf8_view_to_sql() -> Result<()> { let dialect = CustomDialectBuilder::new() diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 56bf887dbde43..9f770f9f45e1d 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -1395,6 +1395,7 @@ impl Unparser<'_> { 
ast::TableAlias { name: self.new_ident_quoted_if_needs(alias), columns, + explicit: true, } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 43fb98e54545c..9205336a52e4e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -374,7 +374,7 @@ pub(crate) fn rewrite_recursive_unnests_bottom_up( pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder"; /* -This is only usedful when used with transform down up +This is only useful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { @@ -496,7 +496,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { type Node = Expr; /// This downward traversal needs to keep track of: - /// - Whether or not some unnest expr has been visited from the top util the current node + /// - Whether or not some unnest expr has been visited from the top until the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** fn f_down(&mut self, expr: Expr) -> Result> { diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs index dd8957c95470d..c8cdf1254f33f 100644 --- a/datafusion/sql/src/values.rs +++ b/datafusion/sql/src/values.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DFSchema, Result, not_impl_err}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::Values as SQLValues; @@ -31,7 +31,13 @@ impl SqlToRel<'_, S> { let SQLValues { explicit_row: _, rows, + value_keyword, } = values; + if value_keyword { + return not_impl_err!( + "`VALUE` keyword not supported. Did you mean `VALUES`?" 
+ )?; + } let empty_schema = Arc::new(DFSchema::empty()); let values = rows diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 46a42ae534af0..4717b843abb53 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -286,7 +286,7 @@ fn roundtrip_crossjoin() -> Result<()> { plan_roundtrip, @r" Projection: j1.j1_id, j2.j2_string - Cross Join: + Cross Join: TableScan: j1 TableScan: j2 " diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 9dc6b895e49ab..4b8667c3c0cbf 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -161,12 +161,26 @@ impl ContextProvider for MockContextProvider { ])), "orders" => Ok(Schema::new(vec![ Field::new("order_id", DataType::UInt32, false), + Field::new("o_orderkey", DataType::UInt32, false), + Field::new("o_custkey", DataType::UInt32, false), + Field::new("o_orderstatus", DataType::Utf8, false), Field::new("customer_id", DataType::UInt32, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), Field::new("o_item_id", DataType::Utf8, false), Field::new("qty", DataType::Int32, false), Field::new("price", DataType::Float64, false), Field::new("delivered", DataType::Boolean, false), ])), + "customer" => Ok(Schema::new(vec![ + Field::new("c_custkey", DataType::UInt32, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::UInt32, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ])), "array" => Ok(Schema::new(vec![ Field::new( "left", @@ -186,8 +200,10 @@ impl ContextProvider for MockContextProvider { ), ])), "lineitem" => Ok(Schema::new(vec![ + Field::new("l_orderkey", DataType::UInt32, false), 
Field::new("l_item_id", DataType::UInt32, false), Field::new("l_description", DataType::Utf8, false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), Field::new("price", DataType::Float64, false), ])), "aggregate_test_100" => Ok(Schema::new(vec![ diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 491873b4afe02..444bdae73ac26 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -995,15 +995,15 @@ fn select_nested_with_filters() { #[test] fn table_with_column_alias() { - let sql = "SELECT a, b, c - FROM lineitem l (a, b, c)"; + let sql = "SELECT a, b, c, d, e + FROM lineitem l (a, b, c, d, e)"; let plan = logical_plan(sql).unwrap(); assert_snapshot!( plan, @r" - Projection: l.a, l.b, l.c + Projection: l.a, l.b, l.c, l.d, l.e SubqueryAlias: l - Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c + Projection: lineitem.l_orderkey AS a, lineitem.l_item_id AS b, lineitem.l_description AS c, lineitem.l_extendedprice AS d, lineitem.price AS e TableScan: lineitem " ); @@ -1017,7 +1017,7 @@ fn table_with_column_alias_number_cols() { assert_snapshot!( err.strip_backtrace(), - @"Error during planning: Source table contains 3 columns but only 2 names given as column alias" + @"Error during planning: Source table contains 5 columns but only 2 names given as column alias" ); } @@ -1058,7 +1058,7 @@ fn natural_left_join() { plan, @r" Projection: a.l_item_id - Left Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + Left Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -1075,7 +1075,7 @@ fn natural_right_join() { plan, @r" Projection: a.l_item_id - Right Join: Using a.l_item_id = b.l_item_id, a.l_description = 
b.l_description, a.price = b.price + Right Join: Using a.l_orderkey = b.l_orderkey, a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.l_extendedprice = b.l_extendedprice, a.price = b.price SubqueryAlias: a TableScan: lineitem SubqueryAlias: b @@ -3395,8 +3395,8 @@ fn cross_join_not_to_inner_join() { @r" Projection: person.id Filter: person.id = person.age - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: person TableScan: orders TableScan: lineitem @@ -3530,11 +3530,11 @@ fn exists_subquery_schema_outer_schema_overlap() { Subquery: Projection: person.first_name Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state) - Cross Join: + Cross Join: TableScan: person SubqueryAlias: p2 TableScan: person - Cross Join: + Cross Join: TableScan: person SubqueryAlias: p TableScan: person @@ -3619,10 +3619,10 @@ fn scalar_subquery_reference_outer_field() { Projection: count(*) Aggregate: groupBy=[[]], aggr=[[count(*)]] Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id - Cross Join: + Cross Join: TableScan: j1 TableScan: j3 - Cross Join: + Cross Join: TableScan: j1 TableScan: j2 " @@ -4801,7 +4801,11 @@ fn test_using_join_wildcard_schema() { // Only columns from one join side should be present let expected_fields = vec![ "o1.order_id".to_string(), + "o1.o_orderkey".to_string(), + "o1.o_custkey".to_string(), + "o1.o_orderstatus".to_string(), "o1.customer_id".to_string(), + "o1.o_totalprice".to_string(), "o1.o_item_id".to_string(), "o1.qty".to_string(), "o1.price".to_string(), @@ -4855,3 +4859,70 @@ fn test_using_join_wildcard_schema() { ] ); } + +#[test] +fn test_2_nested_lateral_join_with_the_deepest_join_referencing_the_outer_most_relation() +{ + let sql = "SELECT * FROM j1 j1_outer, LATERAL ( + SELECT * FROM j1 j1_inner, LATERAL ( + SELECT * FROM j2 WHERE j1_inner.j1_id = j2_id and j1_outer.j1_id=j2_id + ) as j2 +) as j2"; + + let plan = logical_plan(sql).unwrap(); + 
assert_snapshot!( + plan, + @r#" +Projection: j1_outer.j1_id, j1_outer.j1_string, j2.j1_id, j2.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_outer + TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j1_inner.j1_id, j1_inner.j1_string, j2.j2_id, j2.j2_string + Cross Join: + SubqueryAlias: j1_inner + TableScan: j1 + SubqueryAlias: j2 + Subquery: + Projection: j2.j2_id, j2.j2_string + Filter: outer_ref(j1_inner.j1_id) = j2.j2_id AND outer_ref(j1_outer.j1_id) = j2.j2_id + TableScan: j2 +"# + ); +} + +#[test] +fn test_correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation() + { + let sql = "select c_custkey from customer + where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) + ) order by c_custkey"; + + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: customer.c_custkey ASC NULLS LAST + Projection: customer.c_custkey + Filter: customer.c_acctbal < () + Subquery: + Projection: sum(orders.o_totalprice) + Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] + Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND orders.o_totalprice < () + Subquery: + Projection: sum(lineitem.l_extendedprice) AS price + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] + Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) + TableScan: lineitem + TableScan: orders + TableScan: customer +"# + ); +} diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 13ae6e6a57e01..e610739a0312e 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -45,9 +45,9 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, 
optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.53", features = ["derive", "env"] } +clap = { version = "4.5.59", features = ["derive", "env"] } datafusion = { workspace = true, default-features = true, features = ["avro"] } -datafusion-spark = { workspace = true, default-features = true } +datafusion-spark = { workspace = true, features = ["core"] } datafusion-substrait = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } @@ -55,16 +55,16 @@ indicatif = "0.18" itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -postgres-types = { version = "0.2.11", features = ["derive", "with-chrono-0_4"], optional = true } +postgres-types = { version = "0.2.12", features = ["derive", "with-chrono-0_4"], optional = true } # When updating the following dependency verify that sqlite test file regeneration works correctly # by running the regenerate_sqlite_files.sh script. -sqllogictest = "0.29.0" +sqllogictest = "0.29.1" sqlparser = { workspace = true } tempfile = { workspace = true } testcontainers-modules = { workspace = true, features = ["postgres"], optional = true } -thiserror = "2.0.17" +thiserror = "2.0.18" tokio = { workspace = true } -tokio-postgres = { version = "0.7.14", optional = true } +tokio-postgres = { version = "0.7.16", optional = true } [features] avro = ["datafusion/avro"] diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 8037532c09ac3..3571377354eb4 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use clap::Parser; +use clap::{ColorChoice, Parser}; use datafusion::common::instant::Instant; use datafusion::common::utils::get_available_parallelism; use datafusion::common::{DataFusionError, Result, exec_datafusion_err, exec_err}; @@ -44,7 +44,9 @@ use datafusion::common::runtime::SpawnedTask; use futures::FutureExt; use std::ffi::OsStr; use std::fs; +use std::io::{IsTerminal, stdout}; use std::path::{Path, PathBuf}; +use std::str::FromStr; #[cfg(feature = "postgres")] mod postgres_container; @@ -123,6 +125,8 @@ async fn run_tests() -> Result<()> { .unwrap() .progress_chars("##-"); + let colored_output = options.is_colored(); + let start = Instant::now(); let test_files = read_test_files(&options)?; @@ -176,6 +180,7 @@ async fn run_tests() -> Result<()> { m_style_clone, filters.as_ref(), currently_running_sql_tracker_clone, + colored_output, ) .await? } @@ -187,6 +192,7 @@ async fn run_tests() -> Result<()> { m_style_clone, filters.as_ref(), currently_running_sql_tracker_clone, + colored_output, ) .await? 
} @@ -294,6 +300,7 @@ async fn run_test_file_substrait_round_trip( mp_style: ProgressStyle, filters: &[Filter], currently_executing_sql_tracker: CurrentlyExecutingSqlTracker, + colored_output: bool, ) -> Result<()> { let TestFile { path, @@ -323,7 +330,7 @@ async fn run_test_file_substrait_round_trip( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let res = run_file_in_runner(path, runner, filters).await; + let res = run_file_in_runner(path, runner, filters, colored_output).await; pb.finish_and_clear(); res } @@ -335,6 +342,7 @@ async fn run_test_file( mp_style: ProgressStyle, filters: &[Filter], currently_executing_sql_tracker: CurrentlyExecutingSqlTracker, + colored_output: bool, ) -> Result<()> { let TestFile { path, @@ -364,7 +372,7 @@ async fn run_test_file( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let result = run_file_in_runner(path, runner, filters).await; + let result = run_file_in_runner(path, runner, filters, colored_output).await; pb.finish_and_clear(); result } @@ -373,6 +381,7 @@ async fn run_file_in_runner>( path: PathBuf, mut runner: sqllogictest::Runner, filters: &[Filter], + colored_output: bool, ) -> Result<()> { let path = path.canonicalize()?; let records = @@ -386,7 +395,11 @@ async fn run_file_in_runner>( continue; } if let Err(err) = runner.run_async(record).await { - errs.push(format!("{err}")); + if colored_output { + errs.push(format!("{}", err.display(true))); + } else { + errs.push(format!("{err}")); + } } } @@ -479,7 +492,7 @@ async fn run_test_file_with_postgres( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - let result = run_file_in_runner(path, runner, filters).await; + let result = run_file_in_runner(path, runner, filters, false).await; pb.finish_and_clear(); result } @@ 
-772,6 +785,14 @@ struct Options { default_value_t = get_available_parallelism() )] test_threads: usize, + + #[clap( + long, + value_name = "MODE", + help = "Control colored output", + default_value_t = ColorChoice::Auto + )] + color: ColorChoice, } impl Options { @@ -813,6 +834,37 @@ impl Options { eprintln!("WARNING: Ignoring `--show-output` compatibility option"); } } + + /// Determine if colour output should be enabled, respecting --color, NO_COLOR, CARGO_TERM_COLOR, and terminal detection + fn is_colored(&self) -> bool { + // NO_COLOR takes precedence + if std::env::var_os("NO_COLOR").is_some() { + return false; + } + + match self.color { + ColorChoice::Always => true, + ColorChoice::Never => false, + ColorChoice::Auto => { + // CARGO_TERM_COLOR takes precedence over auto-detection + let cargo_term_color = ColorChoice::from_str( + &std::env::var("CARGO_TERM_COLOR") + .unwrap_or_else(|_| "auto".to_string()), + ) + .unwrap_or(ColorChoice::Auto); + match cargo_term_color { + ColorChoice::Always => true, + ColorChoice::Never => false, + ColorChoice::Auto => { + // Auto for both CLI argument and CARGO_TERM_COLOR, + // then use colors by default for non-dumb terminals + stdout().is_terminal() + && std::env::var("TERM").unwrap_or_default() != "dumb" + } + } + } + } + } } /// Performs scratch file check for all test files. 
diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 9ec085b41eec0..8bd0cabcb05b0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -21,6 +21,7 @@ use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::vec; use arrow::array::{ Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, @@ -30,7 +31,7 @@ use arrow::buffer::ScalarBuffer; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use datafusion::catalog::{ - CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session, }; use datafusion::common::{DataFusionError, Result, not_impl_err}; use datafusion::functions::math::abs; @@ -45,6 +46,7 @@ use datafusion::{ datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_spark::SessionStateBuilderSpark; use crate::is_spark_path; use async_trait::async_trait; @@ -80,22 +82,26 @@ impl TestContext { // hardcode target partitions so plans are deterministic .with_target_partitions(4); let runtime = Arc::new(RuntimeEnv::default()); - let mut state = SessionStateBuilder::new() + + let mut state_builder = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) - .with_default_features() - .build(); + .with_default_features(); if is_spark_path(relative_path) { - info!("Registering Spark functions"); - datafusion_spark::register_all(&mut state) - .expect("Can not register Spark functions"); + state_builder = state_builder.with_spark_features(); } + let state = state_builder.build(); + let mut test_ctx = TestContext::new(SessionContext::new_with_state(state)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { + "cte_quoted_reference.slt" => 
{ + info!("Registering strict catalog provider for CTE tests"); + register_strict_orders_catalog(test_ctx.session_ctx()); + } "information_schema_table_types.slt" => { info!("Registering local temporary table"); register_temp_table(test_ctx.session_ctx()).await; @@ -171,6 +177,104 @@ impl TestContext { } } +// ============================================================================== +// Strict Catalog / Schema Provider (sqllogictest-only) +// ============================================================================== +// +// The goal of `cte_quoted_reference.slt` is to exercise end-to-end query planning +// while detecting *unexpected* catalog lookups. +// +// Specifically, if DataFusion incorrectly treats a CTE reference (e.g. `"barbaz"`) +// as a real table reference, the planner will attempt to resolve it through the +// schema provider. The types below deliberately `panic!` on any lookup other than +// the one table we expect (`orders`). +// +// This makes the "extra provider lookup" bug observable in an end-to-end test, +// rather than being silently ignored by default providers that return `Ok(None)` +// for unknown tables. + +#[derive(Debug)] +struct StrictOrdersCatalog { + schema: Arc, +} + +impl CatalogProvider for StrictOrdersCatalog { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + vec!["public".to_string()] + } + + fn schema(&self, name: &str) -> Option> { + (name == "public").then(|| Arc::clone(&self.schema)) + } +} + +#[derive(Debug)] +struct StrictOrdersSchema { + orders: Arc, +} + +#[async_trait] +impl SchemaProvider for StrictOrdersSchema { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + vec!["orders".to_string()] + } + + async fn table( + &self, + name: &str, + ) -> Result>, DataFusionError> { + match name { + "orders" => Ok(Some(Arc::clone(&self.orders))), + other => panic!( + "unexpected table lookup: {other}. 
This maybe indicates a CTE reference was \ + incorrectly treated as a catalog table reference." + ), + } + } + + fn table_exist(&self, name: &str) -> bool { + name == "orders" + } +} + +fn register_strict_orders_catalog(ctx: &SessionContext) { + let schema = Arc::new(Schema::new(vec![Field::new( + "order_id", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .expect("record batch should be valid"); + + let orders = + MemTable::try_new(schema, vec![vec![batch]]).expect("memtable should be valid"); + + let schema_provider: Arc = Arc::new(StrictOrdersSchema { + orders: Arc::new(orders), + }); + + // Override the default "datafusion" catalog for this test file so that any + // unexpected lookup is caught immediately. + ctx.register_catalog( + "datafusion", + Arc::new(StrictOrdersCatalog { + schema: schema_provider, + }), + ); +} + #[cfg(feature = "avro")] pub async fn register_avro_tables(ctx: &mut TestContext) { use datafusion::prelude::AvroReadOptions; @@ -436,14 +540,15 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new( + UnionFields::try_new( // typeids: 3 for int, 1 for string vec![3, 1], vec![ Field::new("int", DataType::Int32, false), Field::new("string", DataType::Utf8, false), ], - ), + ) + .unwrap(), ScalarBuffer::from(vec![3, 1, 3]), None, vec![ diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 6a3d3944e4e81..b0cf32266ea31 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -44,7 +44,7 @@ pub fn setup_scratch_dir(name: &Path) -> Result<()> { /// Trailing whitespace from lines in SLT will typically be removed, but do not fail if it is not /// If particular test wants to cover trailing whitespace on a value, /// it should project additional non-whitespace column on the right. 
-#[allow(clippy::ptr_arg)] +#[expect(clippy::ptr_arg)] pub fn value_normalizer(s: &String) -> String { s.trim_end().to_string() } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a5f3ef04139f4..517467110fe6d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -379,6 +379,59 @@ select array_sort(c1), array_sort(c2) from ( statement ok drop table array_agg_distinct_list_table; +# Test array_agg with DISTINCT and IGNORE NULLS (regression test for issue #19735) +query ? +SELECT array_sort(ARRAY_AGG(DISTINCT x IGNORE NULLS)) as result +FROM (VALUES (1), (2), (NULL), (2), (NULL), (1)) AS t(x); +---- +[1, 2] + +# Test that non-DISTINCT aggregates also preserve IGNORE NULLS when mixed with DISTINCT +# This tests the two-phase aggregation rewrite in SingleDistinctToGroupBy +query I? +SELECT + COUNT(DISTINCT x) as distinct_count, + array_sort(ARRAY_AGG(y IGNORE NULLS)) as y_agg +FROM (VALUES + (1, 10), + (1, 20), + (2, 30), + (3, NULL), + (3, 40), + (NULL, 50) +) AS t(x, y) +---- +3 [10, 20, 30, 40, 50] + +# Test that FILTER clause is preserved in two-phase aggregation rewrite +query II +SELECT + COUNT(DISTINCT x) as distinct_count, + SUM(y) FILTER (WHERE y > 15) as filtered_sum +FROM (VALUES + (1, 10), + (1, 20), + (2, 5), + (2, 30), + (3, 25) +) AS t(x, y) +---- +3 75 + +# Test that ORDER BY is preserved in two-phase aggregation rewrite +query I? 
+SELECT + COUNT(DISTINCT x) as distinct_count, + ARRAY_AGG(y ORDER BY y DESC) as ordered_agg +FROM (VALUES + (1, 10), + (1, 30), + (2, 20), + (2, 40) +) AS t(x, y) +---- +2 [40, 30, 20, 10] + statement error This feature is not implemented: Calling array_agg: LIMIT not supported in function arguments: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 @@ -518,6 +571,16 @@ SELECT covar(c2, c12) FROM aggregate_test_100 ---- -0.079969012479 +query R +SELECT covar_pop(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079163311005 + +query R +SELECT covar(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079962940409 + # single_row_query_covar_1 query R select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq @@ -1226,6 +1289,32 @@ ORDER BY tags, timestamp; 4 tag2 90 67.5 82.5 5 tag2 100 70 90 + +# Test distinct median non-sliding window +query ITRR +SELECT + timestamp, + tags, + value, + median(DISTINCT value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS distinct_median +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 +2 tag1 20 15 +3 tag1 30 20 +4 tag1 40 25 +5 tag1 50 30 +1 tag2 60 60 +2 tag2 70 65 +3 tag2 80 70 +4 tag2 90 75 +5 tag2 100 80 + statement ok DROP TABLE median_window_test; @@ -1234,6 +1323,24 @@ select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median ---- 2.75 Float16 +# This shouldn't be NaN, see: +# https://github.com/apache/datafusion/issues/18945 +query RT +select + percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +2.75 Float16 + +query RT +select + approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(approx_percentile_cont(0.5) within group (order by 
arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +2.75 Float16 + query ?T select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table; ---- @@ -1950,11 +2057,12 @@ statement ok INSERT INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 -# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' +# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' +# With weight=0, the data point does not contribute, so result is NULL query R SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- -Infinity +NULL statement ok DROP TABLE t1; @@ -2273,21 +2381,21 @@ e 115 query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # approx_percentile_cont_with_weight with centroids query TI SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- -a 74 +a 65 b 68 -c 123 -d 124 -e 115 +c 122 +d 123 +e 110 # csv_query_sum_crossjoin query TTI @@ -5422,10 +5530,10 @@ as values statement ok create table t as select - arrow_cast(column1, 'Timestamp(Nanosecond, None)') as nanos, - arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, - arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, - arrow_cast(column1, 'Timestamp(Second, None)') as secs, + arrow_cast(column1, 'Timestamp(ns)') as nanos, + arrow_cast(column1, 'Timestamp(µs)') as micros, + arrow_cast(column1, 'Timestamp(ms)') as millis, + arrow_cast(column1, 'Timestamp(s)') as secs, arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc, arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc, arrow_cast(column1, 'Timestamp(Millisecond, 
Some("UTC"))') as millis_utc, @@ -5508,7 +5616,7 @@ SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag # aggregate_duration_array_agg query T? -SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(Millisecond, None)')) FROM t GROUP BY tag ORDER BY tag; +SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(ms)')) FROM t GROUP BY tag ORDER BY tag; ---- X [0 days 0 hours 0 mins 0.011 secs, 0 days 0 hours 0 mins 0.123 secs] Y [NULL, 0 days 0 hours 0 mins 0.432 secs] @@ -6639,7 +6747,12 @@ from aggregate_test_100; ---- 0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 - +query R +select + regr_slope(arrow_cast(c12, 'Float16'), arrow_cast(c11, 'Float16')) +from aggregate_test_100; +---- +0.051477733249 # regr_*() functions ignore NULLs query RRIRRRRRR @@ -7872,8 +7985,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(Int64(1)), 2 as count()] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[count(Int64(1))@0 as count(Int64(1)), count(Int64(1))@0 as count()] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query II select count(1), count(*) from t; @@ -7888,8 +8002,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(Int64(1)), 2 as count(*)] -02)--PlaceholderRowExec +01)ProjectionExec: expr=[count(Int64(1))@0 as count(Int64(1)), count(Int64(1))@0 as count(*)] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query II select count(), count(*) from t; @@ -7904,8 +8019,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[2 as count(), 2 as count(*)] -02)--PlaceholderRowExec +01)ProjectionExec: 
expr=[count(Int64(1))@0 as count(), count(Int64(1))@0 as count(*)] +02)--ProjectionExec: expr=[2 as count(Int64(1))] +03)----PlaceholderRowExec query TT explain select count(1) * count(2) from t; diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 05f3e02bbc1b3..19ead8965ed01 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -344,5 +344,123 @@ physical_plan 06)----------DataSourceExec: partitions=1, partition_sizes=[1] +## Test GROUP BY with ORDER BY on the same column (no aggregate functions) +statement ok +CREATE TABLE ids(id int, value int) AS VALUES +(1, 10), +(2, 20), +(3, 30), +(4, 40), +(1, 50), +(2, 60), +(5, 70); + +query TT +explain select id from ids group by id order by id desc limit 3; +---- +logical_plan +01)Sort: ids.id DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id desc limit 3; +---- +5 +4 +3 + +query TT +explain select id from ids group by id order by id asc limit 2; +---- +logical_plan +01)Sort: ids.id ASC NULLS LAST, fetch=2 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], 
aggr=[], lim=[2] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[2] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id asc limit 2; +---- +1 +2 + +# Test with larger limit than distinct values +query I +select id from ids group by id order by id desc limit 100; +---- +5 +4 +3 +2 +1 + +# Test with bigint group by +statement ok +CREATE TABLE values_table (value INT, category BIGINT) AS VALUES +(10, 100), +(20, 200), +(30, 300), +(40, 400), +(50, 500), +(20, 200), +(10, 100), +(40, 400); + +query TT +explain select category from values_table group by category order by category desc limit 3; +---- +logical_plan +01)Sort: values_table.category DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[values_table.category]], aggr=[[]] +03)----TableScan: values_table projection=[category] +physical_plan +01)SortPreservingMergeExec: [category@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[category@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[category@0 as category], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select category from values_table group by category order by category desc limit 3; +---- +500 +400 +300 + +# Test with integer group by +query I +select value from values_table group by value order by value asc limit 3; +---- +10 +20 +30 + +# Test DISTINCT semantics are preserved +query I +select count(*) from (select category from values_table group by category order by category desc limit 3); +---- +3 + +statement ok +drop table values_table; + +statement ok +drop table ids; + statement ok drop table traces; diff --git 
a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c31f3d0702358..f675763120718 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2577,6 +2577,31 @@ NULL NULL NULL NULL NULL NULL +# maintains inner nullability +query ?T +select array_sort(column1), arrow_typeof(array_sort(column1)) +from values + (arrow_cast([], 'List(non-null Int32)')), + (arrow_cast(NULL, 'List(non-null Int32)')), + (arrow_cast([1, 3, 5, -5], 'List(non-null Int32)')) +; +---- +[] List(non-null Int32) +NULL List(non-null Int32) +[-5, 1, 3, 5] List(non-null Int32) + +query ?T +select column1, arrow_typeof(column1) +from values (array_sort(arrow_cast([1, 3, 5, -5], 'LargeList(non-null Int32)'))); +---- +[-5, 1, 3, 5] LargeList(non-null Int32) + +query ?T +select column1, arrow_typeof(column1) +from values (array_sort(arrow_cast([1, 3, 5, -5], 'FixedSizeList(4 x non-null Int32)'))); +---- +[-5, 1, 3, 5] List(non-null Int32) + query ? select array_sort([struct('foo', 3), struct('foo', 1), struct('bar', 1)]) ---- @@ -3231,6 +3256,99 @@ drop table array_repeat_table; statement ok drop table large_array_repeat_table; +# array_repeat: arrays with NULL counts +statement ok +create table array_repeat_null_count_table +as values +(1, 2), +(2, null), +(3, 1), +(4, -1), +(null, null); + +query I? +select column1, array_repeat(column1, column2) from array_repeat_null_count_table; +---- +1 [1, 1] +2 NULL +3 [3] +4 [] +NULL NULL + +statement ok +drop table array_repeat_null_count_table + +# array_repeat: nested arrays with NULL counts +statement ok +create table array_repeat_nested_null_count_table +as values +([[1, 2], [3, 4]], 2), +([[5, 6], [7, 8]], null), +([[null, null], [9, 10]], 1), +(null, 3), +([[11, 12]], -1); + +query ?? 
+select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]] +NULL [NULL, NULL, NULL] +[[11, 12]] [] + +statement ok +drop table array_repeat_nested_null_count_table + +# array_repeat edge cases: empty arrays +query ??? +select array_repeat([], 3), array_repeat([], 0), array_repeat([], null); +---- +[[], [], []] [] NULL + +query ?? +select array_repeat(null::int, 0), array_repeat(null::int, null); +---- +[] NULL + +# array_repeat LargeList with NULL count +statement ok +create table array_repeat_large_list_null_table +as values +(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2), +(arrow_cast([4, 5], 'LargeList(Int64)'), null), +(arrow_cast(null, 'LargeList(Int64)'), 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table; +---- +[1, 2, 3] [[1, 2, 3], [1, 2, 3]] +[4, 5] NULL +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_list_null_table + +# array_repeat edge cases: LargeList nested with NULL count +statement ok +create table array_repeat_large_nested_null_table +as values +(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2), +(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null), +(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1), +(null, 3); + +query ?? +select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table; +---- +[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[5, 6], [7, 8]] NULL +[[NULL, NULL]] [[[NULL, NULL]]] +NULL [NULL, NULL, NULL] + +statement ok +drop table array_repeat_large_nested_null_table + ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # test with empty array @@ -4747,10 +4865,11 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList [] # array_union scalar function #7 -query ? 
-select array_union([[null]], []); ----- -[[]] +# re-enable when https://github.com/apache/arrow-rs/issues/9227 is fixed +# query ? +# select array_union([[null]], []); +# ---- +# [[]] query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_union' function: select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); @@ -4770,12 +4889,12 @@ select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([[ query ? select array_union(null, []); ---- -[] +NULL query ? select array_union(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL # array_union scalar function #10 query ? @@ -4787,23 +4906,23 @@ NULL query ? select array_union([1, 1, 2, 2, 3, 3], null); ---- -[1, 2, 3] +NULL query ? select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[1, 2, 3] +NULL # array_union scalar function #12 query ? select array_union(null, [1, 1, 2, 2, 3, 3]); ---- -[1, 2, 3] +NULL query ? select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[1, 2, 3] +NULL # array_union scalar function #13 query ? @@ -4838,6 +4957,36 @@ NULL NULL NULL +query ? +select array_union(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_union([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? +select array_intersect(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? +select array_intersect([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + +query ? +select array_except(arrow_cast(null, 'List(Int64)'), [1, 2]); +---- +NULL + +query ? 
+select array_except([1, 2], arrow_cast(null, 'List(Int64)')); +---- +NULL + # list_to_string scalar function #4 (function alias `array_to_string`) query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); @@ -6457,10 +6606,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6485,10 +6633,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, 
c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6513,10 +6660,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6541,10 +6687,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) 
([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6569,10 +6714,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6599,10 +6743,9 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------ProjectionExec: expr=[] -06)----------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IS NOT NULL OR NULL -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) 
IS NOT NULL OR NULL, projection=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] # any operator query ? @@ -6689,7 +6832,7 @@ from array_distinct_table_2D; ---- [[1, 2], [3, 4], [5, 6]] [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] -[NULL, [5, 6]] +[[5, 6], NULL] query ? select array_distinct(column1) @@ -6721,7 +6864,7 @@ from array_distinct_table_2D_fixed; ---- [[1, 2], [3, 4], [5, 6]] [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] -[NULL, [5, 6]] +[[5, 6], NULL] query ??? select array_intersect(column1, column2), @@ -6756,7 +6899,7 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from array_intersect_table_1D_Boolean; ---- -[] [false, true] [false] +[] [true, false] [false] [false] [true] [true] query ??? @@ -6765,7 +6908,7 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from large_array_intersect_table_1D_Boolean; ---- -[] [false, true] [false] +[] [true, false] [false] [false] [true] [true] query ??? @@ -6774,8 +6917,8 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from array_intersect_table_1D_UTF8; ---- -[bc] [arrow, rust] [] -[] [arrow, datafusion, rust] [arrow, rust] +[bc] [rust, arrow] [] +[] [datafusion, rust, arrow] [rust, arrow] query ??? select array_intersect(column1, column2), @@ -6783,8 +6926,8 @@ select array_intersect(column1, column2), array_intersect(column5, column6) from large_array_intersect_table_1D_UTF8; ---- -[bc] [arrow, rust] [] -[] [arrow, datafusion, rust] [arrow, rust] +[bc] [rust, arrow] [] +[] [datafusion, rust, arrow] [rust, arrow] query ? select array_intersect(column1, column2) @@ -6888,27 +7031,27 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'Large query ? select array_intersect([1, 1, 2, 2, 3, 3], null); ---- -[] +NULL query ? 
select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -[] +NULL query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect([], null); ---- -[] +NULL query ? select array_intersect([[1,2,3]], [[]]); @@ -6923,17 +7066,17 @@ select array_intersect([[null]], [[]]); query ? select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); ---- -[] +NULL query ? select array_intersect(null, []); ---- -[] +NULL query ? select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -[] +NULL query ? select array_intersect(null, null); @@ -7189,12 +7332,12 @@ select generate_series('2021-01-01'::timestamp, '2021-01-01T15:00:00'::timestamp # Other timestamp types are coerced to nanosecond query ? -select generate_series(arrow_cast('2021-01-01'::timestamp, 'Timestamp(Second, None)'), '2021-01-01T15:00:00'::timestamp, INTERVAL '1' HOUR); +select generate_series(arrow_cast('2021-01-01'::timestamp, 'Timestamp(s)'), '2021-01-01T15:00:00'::timestamp, INTERVAL '1' HOUR); ---- [2021-01-01T00:00:00, 2021-01-01T01:00:00, 2021-01-01T02:00:00, 2021-01-01T03:00:00, 2021-01-01T04:00:00, 2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00] query ? 
-select generate_series('2021-01-01'::timestamp, arrow_cast('2021-01-01T15:00:00'::timestamp, 'Timestamp(Microsecond, None)'), INTERVAL '1' HOUR); +select generate_series('2021-01-01'::timestamp, arrow_cast('2021-01-01T15:00:00'::timestamp, 'Timestamp(µs)'), INTERVAL '1' HOUR); ---- [2021-01-01T00:00:00, 2021-01-01T01:00:00, 2021-01-01T02:00:00, 2021-01-01T03:00:00, 2021-01-01T04:00:00, 2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00] @@ -7476,7 +7619,7 @@ select array_except(column1, column2) from array_except_table; [2] [] NULL -[1, 2] +NULL NULL statement ok @@ -7497,7 +7640,7 @@ select array_except(column1, column2) from array_except_nested_list_table; ---- [[1, 2]] [[3]] -[[1, 2], [3]] +NULL NULL [] @@ -7536,7 +7679,7 @@ select array_except(column1, column2) from array_except_table_ut8; ---- [b, c] [a, bc] -[a, bc, def] +NULL NULL statement ok @@ -7558,7 +7701,7 @@ select array_except(column1, column2) from array_except_table_bool; [true] [true] [false] -[true, false] +NULL NULL statement ok @@ -7567,7 +7710,7 @@ drop table array_except_table_bool; query ? select array_except([], null); ---- -[] +NULL query ? 
select array_except([], []); diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index ee1f204664a14..0c69e8591c3a4 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\) but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'arrow_cast' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Int64 \(DataType: Int64\) SELECT arrow_cast('1', 43) query error DataFusion error: Execution error: arrow_cast requires its second argument to be a non\-empty constant string @@ -123,10 +123,10 @@ SELECT arrow_typeof(arrow_cast('foo', 'Utf8View')) as col_utf8_view, arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary, arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)')) as col_ts_ms, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)')) as col_ts_us, - arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)')) as col_ts_ns, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)')) as col_ts_s, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ms)')) as col_ts_ms, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(µs)')) as col_ts_us, + arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ns)')) as col_ts_ns, 
arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, Some("+08:00"))')) as col_tstz_s, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, Some("+08:00"))')) as col_tstz_ms, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, Some("+08:00"))')) as col_tstz_us, @@ -242,10 +242,10 @@ drop table foo statement ok create table foo as select - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us, - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)') as col_ts_s, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ms)') as col_ts_ms, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(µs)') as col_ts_us, + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(ns)') as col_ts_ns ; ## Ensure each column in the table has the expected type diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 1077c32e46f35..c4a21deeff26b 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -311,3 +311,13 @@ Foo foo Foo foo NULL NULL NULL NULL Bar Bar Bar Bar FooBar fooBar FooBar fooBar + +# show helpful error msg when Binary type is used with string functions +query error DataFusion error: Error during planning: Function 'split_part' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Binary \(DataType: Binary\)\.\n\nHint: Binary types are not automatically coerced 
to String\. Use CAST\(column AS VARCHAR\) to convert Binary data to String\. +SELECT split_part(binary, '~', 2) FROM t WHERE binary IS NOT NULL LIMIT 1; + +# ensure the suggested CAST workaround works +query T +SELECT split_part(CAST(binary AS VARCHAR), 'o', 2) FROM t WHERE binary = X'466f6f'; +---- +(empty) diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 074d216ac7524..3953878ceb666 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -383,9 +383,10 @@ SELECT column2, column3, column4 FROM t; ---- {foo: a, xxx: b} {xxx: c, foo: d} {xxx: e} -# coerce structs with different field orders, -# (note the *value*s are from column2 but the field name is 'xxx', as the coerced -# type takes the field name from the last argument (column3) +# coerce structs with different field orders +# With name-based struct coercion, matching fields by name: +# column2={foo:a, xxx:b} unified with column3={xxx:c, foo:d} +# Result uses the THEN branch's field order (when executed): {xxx: b, foo: a} query ? SELECT case @@ -394,9 +395,10 @@ SELECT end FROM t; ---- -{xxx: a, foo: b} +{xxx: b, foo: a} # coerce structs with different field orders +# When ELSE branch executes, uses its field order: {xxx: c, foo: d} query ? 
SELECT case @@ -407,8 +409,9 @@ FROM t; ---- {xxx: c, foo: d} -# coerce structs with subset of fields -query error Failed to coerce then +# coerce structs with subset of fields - field count mismatch causes type coercion failure +# column3 has 2 fields but column4 has only 1 field +query error DataFusion error: type_coercion\ncaused by\nError during planning: Failed to coerce then .* and else .* to common types in CASE WHEN expression SELECT case when column1 > 0 then column3 @@ -618,6 +621,59 @@ a b c +query I +SELECT CASE WHEN d != 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +1 +NULL +-1 + +query I +SELECT CASE WHEN d > 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +1 +NULL +NULL + +query I +SELECT CASE WHEN d < 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d) +---- +NULL +NULL +-1 + +# single WHEN, no ELSE (absent) +query I +SELECT CASE WHEN a > 0 THEN b END +FROM (VALUES (1, 10), (0, 20)) AS t(a, b); +---- +10 +NULL + +# single WHEN, explicit ELSE NULL +query I +SELECT CASE WHEN a > 0 THEN b ELSE NULL END +FROM (VALUES (1, 10), (0, 20)) AS t(a, b); +---- +10 +NULL + +# fallible THEN expression should only be evaluated on true rows +query I +SELECT CASE WHEN a > 0 THEN 10 / a END +FROM (VALUES (1), (0)) AS t(a); +---- +10 +NULL + +# all-false path returns typed NULLs +query I +SELECT CASE WHEN a < 0 THEN b END +FROM (VALUES (1, 10), (2, 20)) AS t(a, b); +---- +NULL +NULL + # EvalMethod::WithExpression using subset of all selected columns in case expression query III SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN b END, b, c diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index 4c60a4365ee26..42b7cfafdaa63 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -26,10 +26,28 @@ # COPY (SELECT * FROM 'hits.parquet' LIMIT 10) TO 
'clickbench_hits_10.parquet' (FORMAT PARQUET); statement ok -CREATE EXTERNAL TABLE hits +CREATE EXTERNAL TABLE hits_raw STORED AS PARQUET LOCATION '../core/tests/data/clickbench_hits_10.parquet'; +# ClickBench encodes EventDate as UInt16 days since epoch. +statement ok +CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw; + +# Verify EventDate transformation from UInt16 to DATE +query D +SELECT "EventDate" FROM hits LIMIT 1; +---- +2013-07-15 + +# Verify the raw value is still UInt16 in hits_raw +query I +SELECT "EventDate" FROM hits_raw LIMIT 1; +---- +15901 # queries.sql came from # https://github.com/ClickHouse/ClickBench/blob/8b9e3aa05ea18afa427f14909ddc678b8ef0d5e6/datafusion/queries.sql @@ -64,10 +82,10 @@ SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; ---- 1 -query II +query DD SELECT MIN("EventDate"), MAX("EventDate") FROM hits; ---- -15901 15901 +2013-07-15 2013-07-15 query II SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; @@ -167,7 +185,8 @@ query TTTII SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; ---- -query IITIIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIII +query IITIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIIID + SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; ---- @@ -262,7 +281,7 @@ query IIITTI SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND 
"IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; ---- -query III +query IDI SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; ---- @@ -293,4 +312,7 @@ SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitCo statement ok -drop table hits; +drop view hits; + +statement ok +drop table hits_raw; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 3dac92938772c..4fd77be045c13 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -42,6 +42,63 @@ physical_plan statement error DataFusion error: Error during planning: WITH query name "a" specified more than once WITH a AS (SELECT 1), a AS (SELECT 2) SELECT * FROM a; +statement ok +CREATE TABLE orders AS VALUES (1), (2); + +########## +## CTE Reference Resolution +########## + +# These tests exercise CTE reference resolution with and without identifier +# normalization. The session is configured with a strict catalog/schema provider +# (see `datafusion/sqllogictest/src/test_context.rs`) that only provides the +# `orders` table and panics on any unexpected table lookup. +# +# This makes it observable if DataFusion incorrectly treats a CTE reference as a +# catalog lookup. +# +# Refs: https://github.com/apache/datafusion/issues/18932 +# +# NOTE: This test relies on a strict catalog/schema provider registered in +# `datafusion/sqllogictest/src/test_context.rs` that provides only the `orders` +# table and panics on unexpected lookups. 
+ +statement ok +set datafusion.sql_parser.enable_ident_normalization = true; + +query I +with barbaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with BarBaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with barbaz as (select * from orders) select * from barbaz; +---- +1 +2 + +statement ok +set datafusion.sql_parser.enable_ident_normalization = false; + +query I +with barbaz as (select * from orders) select * from "barbaz"; +---- +1 +2 + +query I +with barbaz as (select * from orders) select * from barbaz; +---- +1 +2 + # Test disabling recursive CTE statement ok set datafusion.execution.enable_recursive_ctes = false; @@ -996,7 +1053,7 @@ query TT explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; ---- logical_plan @@ -1021,7 +1078,7 @@ query TT explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; ---- logical_plan @@ -1160,5 +1217,5 @@ query error DataFusion error: This feature is not implemented: Recursive CTEs ar explain WITH RECURSIVE numbers AS ( select 1 as n UNION ALL - select n + 1 FROM numbers WHERE N < 10 + select n + 1 FROM numbers WHERE n < 10 ) select * from numbers; diff --git a/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt b/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt index bc796a51ff5a4..8e85c8f90580e 100644 --- a/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt +++ b/datafusion/sqllogictest/test_files/datetime/arith_date_time.slt @@ -113,4 +113,3 @@ SELECT '2001-09-28'::date / '03:00'::time query error Invalid timestamp arithmetic operation SELECT '2001-09-28'::date % '03:00'::time - diff --git a/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt 
b/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt index 10381346f8359..aeeebe73db701 100644 --- a/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt +++ b/datafusion/sqllogictest/test_files/datetime/arith_timestamp_duration.slt @@ -144,4 +144,4 @@ query error Invalid timestamp arithmetic operation SELECT '2001-09-28T01:00:00'::timestamp % arrow_cast(12345, 'Duration(Second)'); query error Invalid timestamp arithmetic operation -SELECT '2001-09-28T01:00:00'::timestamp / arrow_cast(12345, 'Duration(Second)'); \ No newline at end of file +SELECT '2001-09-28T01:00:00'::timestamp / arrow_cast(12345, 'Duration(Second)'); diff --git a/datafusion/sqllogictest/test_files/datetime/date_part.slt b/datafusion/sqllogictest/test_files/datetime/date_part.slt index bee8602d80bd2..79d6d8ac05098 100644 --- a/datafusion/sqllogictest/test_files/datetime/date_part.slt +++ b/datafusion/sqllogictest/test_files/datetime/date_part.slt @@ -19,7 +19,7 @@ # for the same function). 
-## Begin tests fo rdate_part with columns and timestamp's with timezones +## Begin tests for date_part with columns and timestamp's with timezones # Source data table has # timestamps with millisecond (very common timestamp precision) and nanosecond (maximum precision) timestamps @@ -40,30 +40,32 @@ with t as (values ) SELECT -- nanoseconds, with no, utc, and local timezone - arrow_cast(column1, 'Timestamp(Nanosecond, None)') as ts_nano_no_tz, + arrow_cast(column1, 'Timestamp(ns)') as ts_nano_no_tz, + arrow_cast(column1, 'Timestamp(Nanosecond, None)') as ts_nano_no_tz_old_format, arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as ts_nano_utc, arrow_cast(column1, 'Timestamp(Nanosecond, Some("America/New_York"))') as ts_nano_eastern, -- milliseconds, with no, utc, and local timezone - arrow_cast(column1, 'Timestamp(Millisecond, None)') as ts_milli_no_tz, + arrow_cast(column1, 'Timestamp(ms)') as ts_milli_no_tz, + arrow_cast(column1, 'Timestamp(Millisecond, None)') as ts_milli_no_tz_old_format, arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as ts_milli_utc, arrow_cast(column1, 'Timestamp(Millisecond, Some("America/New_York"))') as ts_milli_eastern FROM t; -query PPPPPP +query PPPPPPPP SELECT * FROM source_ts; ---- -2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 -2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 -2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 -2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 -2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 -2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 
2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 -2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 -2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 -2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 -2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456Z 2019-12-31T19:00:00.123456-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 -2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789Z 2019-12-31T19:00:00.123456789-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 2020-01-01T00:00:00 2020-01-01T00:00:00 2020-01-01T00:00:00Z 2019-12-31T19:00:00-05:00 +2021-01-01T00:00:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 2021-01-01T00:00:00 2021-01-01T00:00:00 2021-01-01T00:00:00Z 2020-12-31T19:00:00-05:00 +2020-09-01T00:00:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 2020-09-01T00:00:00 2020-09-01T00:00:00 2020-09-01T00:00:00Z 2020-08-31T20:00:00-04:00 +2020-01-25T00:00:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 2020-01-25T00:00:00 2020-01-25T00:00:00 2020-01-25T00:00:00Z 2020-01-24T19:00:00-05:00 +2020-01-24T00:00:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 2020-01-24T00:00:00 2020-01-24T00:00:00 2020-01-24T00:00:00Z 2020-01-23T19:00:00-05:00 +2020-01-01T12:00:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 2020-01-01T12:00:00 2020-01-01T12:00:00 2020-01-01T12:00:00Z 2020-01-01T07:00:00-05:00 +2020-01-01T00:30:00 2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 2020-01-01T00:30:00 
2020-01-01T00:30:00 2020-01-01T00:30:00Z 2019-12-31T19:30:00-05:00 +2020-01-01T00:00:30 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 2020-01-01T00:00:30 2020-01-01T00:00:30 2020-01-01T00:00:30Z 2019-12-31T19:00:30-05:00 +2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456 2020-01-01T00:00:00.123456Z 2019-12-31T19:00:00.123456-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 +2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789 2020-01-01T00:00:00.123456789Z 2019-12-31T19:00:00.123456789-05:00 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123 2020-01-01T00:00:00.123Z 2019-12-31T19:00:00.123-05:00 # date_part (year) with columns and explicit timestamp query IIIIII @@ -81,6 +83,23 @@ SELECT date_part('year', ts_nano_no_tz), date_part('year', ts_nano_utc), date_pa 2020 2020 2019 2020 2020 2019 2020 2020 2019 2020 2020 2019 +# date_part (isoyear) with columns and explicit timestamp +query IIIIII +SELECT date_part('isoyear', ts_nano_no_tz), date_part('isoyear', ts_nano_utc), date_part('isoyear', ts_nano_eastern), date_part('isoyear', ts_milli_no_tz), date_part('isoyear', ts_milli_utc), date_part('isoyear', ts_milli_eastern) FROM source_ts; +---- +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 +2020 2020 2020 2020 2020 2020 + + # date_part (month) query IIIIII SELECT date_part('month', ts_nano_no_tz), date_part('month', ts_nano_utc), date_part('month', ts_nano_eastern), date_part('month', ts_milli_no_tz), date_part('month', 
ts_milli_utc), date_part('month', ts_milli_eastern) FROM source_ts; @@ -228,6 +247,26 @@ SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') ---- 2020 +query I +SELECT date_part('ISOYEAR', CAST('2000-01-01' AS DATE)) +---- +1999 + +query I +SELECT EXTRACT(isoyear FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT("isoyear" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT('isoyear' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + query I SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ---- @@ -865,9 +904,15 @@ SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) ---- 8 +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Interval\(YearMonth\) +SELECT extract(isoyear from arrow_cast('10 years', 'Interval(YearMonth)')) + query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Interval\(DayTime\) +SELECT extract(isoyear from arrow_cast('10 days', 'Interval(DayTime)')) + query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) @@ -936,6 +981,57 @@ SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) ---- NULL +# extract epoch from intervals +query R +SELECT extract(epoch from interval '15 minutes') +---- +900 + +query R +SELECT extract(epoch from interval '1 hour') +---- +3600 + +query R +SELECT extract(epoch from interval '1 day') +---- +86400 + +query R +SELECT extract(epoch from interval '1 month') +---- +2592000 + +query R +SELECT extract(epoch from arrow_cast('3 days', 'Interval(DayTime)')) +---- +259200 + +query R +SELECT extract(epoch from arrow_cast('100 milliseconds', 
'Interval(MonthDayNano)')) +---- +0.1 + +query R +SELECT extract(epoch from arrow_cast('500 microseconds', 'Interval(MonthDayNano)')) +---- +0.0005 + +query R +SELECT extract(epoch from arrow_cast('2500 nanoseconds', 'Interval(MonthDayNano)')) +---- +0.0000025 + +query R +SELECT extract(epoch from arrow_cast('1 month 2 days 500 milliseconds', 'Interval(MonthDayNano)')) +---- +2764800.5 + +query R +SELECT extract(epoch from arrow_cast('2 months', 'Interval(YearMonth)')) +---- +5184000 + statement ok create table t (id int, i interval) as values (0, interval '5 months 1 day 10 nanoseconds'), @@ -1011,6 +1107,9 @@ SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(s\) SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) +query error DataFusion error: Arrow error: Compute error: YearISO does not support: Duration\(s\) +SELECT extract(isoyear from arrow_cast(864000, 'Duration(Second)')) + query I SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) ---- @@ -1023,6 +1122,11 @@ SELECT (date_part('year', now()) = EXTRACT(year FROM now())) ---- true +query B +SELECT (date_part('isoyear', now()) = EXTRACT(isoyear FROM now())) +---- +true + query B SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) ---- @@ -1090,3 +1194,563 @@ query I SELECT EXTRACT('isodow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 1 + +## Preimage tests + +statement ok +create table t1(c DATE) as VALUES (NULL), ('1990-01-01'), ('2024-01-01'), ('2030-01-01'); + +# Simple optimizations, col on LHS + +query D +select c from t1 where extract(year from c) = 2024; +---- +2024-01-01 + +query D +select c from t1 where extract(year from c) <> 2024; +---- +1990-01-01 +2030-01-01 + +query D +select c from t1 where extract(year from c) > 2024; +---- +2030-01-01 + +query D +select c from t1 where extract(year from c) < 2024; +---- +1990-01-01 + +query D +select c from 
t1 where extract(year from c) >= 2024; +---- +2024-01-01 +2030-01-01 + +query D +select c from t1 where extract(year from c) <= 2024; +---- +1990-01-01 +2024-01-01 + +query D +select c from t1 where extract(year from c) is not distinct from 2024 +---- +2024-01-01 + +query D +select c from t1 where extract(year from c) is distinct from 2024 +---- +NULL +1990-01-01 +2030-01-01 + +# IN list optimization +query D +select c from t1 where extract(year from c) in (1990, 2024); +---- +1990-01-01 +2024-01-01 + +# NOT IN list optimization (NULL does not satisfy NOT IN) +query D +select c from t1 where extract(year from c) not in (1990, 2024); +---- +2030-01-01 + +# Check that date_part is not in the explain statements + +query TT +explain select c from t1 where extract (year from c) = 2024 +---- +logical_plan +01)Filter: t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) <> 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") OR t1.c >= Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) > 2024 +---- +logical_plan +01)Filter: t1.c >= Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) < 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2024-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) >= 2024 +---- 
+logical_plan +01)Filter: t1.c >= Date32("2024-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) <= 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) is not distinct from 2024 +---- +logical_plan +01)Filter: t1.c IS NOT NULL AND t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 IS NOT NULL AND c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) is distinct from 2024 +---- +logical_plan +01)Filter: t1.c < Date32("2024-01-01") OR t1.c >= Date32("2025-01-01") OR t1.c IS NULL +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 OR c@0 IS NULL +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (year from c) in (1990, 2024) +---- +logical_plan +01)Filter: t1.c >= Date32("1990-01-01") AND t1.c < Date32("1991-01-01") OR t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 1990-01-01 AND c@0 < 1991-01-01 OR c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Simple optimizations, column on RHS + +query D +select c from t1 where 2024 = extract(year from c); +---- +2024-01-01 + +query D +select c from t1 where 2024 <> extract(year from c); +---- +1990-01-01 +2030-01-01 + +query D +select c from t1 where 2024 < extract(year from c); +---- +2030-01-01 + +query D +select c from t1 
where 2024 > extract(year from c); +---- +1990-01-01 + +query D +select c from t1 where 2024 <= extract(year from c); +---- +2024-01-01 +2030-01-01 + +query D +select c from t1 where 2024 >= extract(year from c); +---- +1990-01-01 +2024-01-01 + +query D +select c from t1 where 2024 is not distinct from extract(year from c); +---- +2024-01-01 + +query D +select c from t1 where 2024 is distinct from extract(year from c); +---- +NULL +1990-01-01 +2030-01-01 + +# Check explain statements for optimizations for other interval types + +query TT +explain select c from t1 where extract (quarter from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("QUARTER"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(QUARTER, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (month from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MONTH"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MONTH, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (week from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("WEEK"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(WEEK, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (day from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("DAY"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(DAY, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (hour from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("HOUR"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(HOUR, c@0) = 2024 
+02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (minute from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MINUTE"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MINUTE, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (second from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("SECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(SECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (millisecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MILLISECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MILLISECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (microsecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("MICROSECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(MICROSECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (nanosecond from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("NANOSECOND"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(NANOSECOND, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (dow from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("DOW"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(DOW, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (doy from c) = 2024 +---- 
+logical_plan +01)Filter: date_part(Utf8("DOY"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(DOY, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (epoch from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("EPOCH"), t1.c) = Float64(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(EPOCH, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c from t1 where extract (isodow from c) = 2024 +---- +logical_plan +01)Filter: date_part(Utf8("ISODOW"), t1.c) = Int32(2024) +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: date_part(ISODOW, c@0) = 2024 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Simple optimize different datatypes + +statement ok +create table t2( + c1_date32 DATE, + c2_ts_sec timestamp, + c3_ts_mili timestamp, + c4_ts_micro timestamp, + c5_ts_nano timestamp +) as VALUES + (NULL, + NULL, + NULL, + NULL, + NULL), + ('1990-05-20', + '1990-05-20T00:00:10'::timestamp, + '1990-05-20T00:00:10.987'::timestamp, + '1990-05-20T00:00:10.987654'::timestamp, + '1990-05-20T00:00:10.987654321'::timestamp), + ('2024-01-01', + '2024-01-01T00:00:00'::timestamp, + '2024-01-01T00:00:00.123'::timestamp, + '2024-01-01T00:00:00.123456'::timestamp, + '2024-01-01T00:00:00.123456789'::timestamp), + ('2030-12-31', + '2030-12-31T23:59:59'::timestamp, + '2030-12-31T23:59:59.001'::timestamp, + '2030-12-31T23:59:59.001234'::timestamp, + '2030-12-31T23:59:59.001234567'::timestamp) +; + +query D +select c1_date32 from t2 where extract(year from c1_date32) = 2024; +---- +2024-01-01 + +query D +select c1_date32 from t2 where extract(year from c1_date32) <> 2024; +---- +1990-05-20 +2030-12-31 + +query P +select c2_ts_sec from t2 where extract(year from c2_ts_sec) > 2024; +---- +2030-12-31T23:59:59 + +query P +select c3_ts_mili from t2 where extract(year from 
c3_ts_mili) < 2024; +---- +1990-05-20T00:00:10.987 + +query P +select c4_ts_micro from t2 where extract(year from c4_ts_micro) >= 2024; +---- +2024-01-01T00:00:00.123456 +2030-12-31T23:59:59.001234 + +query P +select c5_ts_nano from t2 where extract(year from c5_ts_nano) <= 2024; +---- +1990-05-20T00:00:10.987654321 +2024-01-01T00:00:00.123456789 + +query D +select c1_date32 from t2 where extract(year from c1_date32) is not distinct from 2024 +---- +2024-01-01 + +query D +select c1_date32 from t2 where extract(year from c1_date32) is distinct from 2024 +---- +NULL +1990-05-20 +2030-12-31 + +# Check that date_part is not in the explain statements for other datatypes + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) = 2024 +---- +logical_plan +01)Filter: t2.c1_date32 >= Date32("2024-01-01") AND t2.c1_date32 < Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 >= 2024-01-01 AND c1_date32@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) <> 2024 +---- +logical_plan +01)Filter: t2.c1_date32 < Date32("2024-01-01") OR t2.c1_date32 >= Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 < 2024-01-01 OR c1_date32@0 >= 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c2_ts_sec from t2 where extract (year from c2_ts_sec) > 2024 +---- +logical_plan +01)Filter: t2.c2_ts_sec >= TimestampNanosecond(1735689600000000000, None) +02)--TableScan: t2 projection=[c2_ts_sec] +physical_plan +01)FilterExec: c2_ts_sec@0 >= 1735689600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c3_ts_mili from t2 where extract (year from c3_ts_mili) < 2024 +---- +logical_plan +01)Filter: t2.c3_ts_mili < TimestampNanosecond(1704067200000000000, None) +02)--TableScan: t2 
projection=[c3_ts_mili] +physical_plan +01)FilterExec: c3_ts_mili@0 < 1704067200000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c4_ts_micro from t2 where extract (year from c4_ts_micro) >= 2024 +---- +logical_plan +01)Filter: t2.c4_ts_micro >= TimestampNanosecond(1704067200000000000, None) +02)--TableScan: t2 projection=[c4_ts_micro] +physical_plan +01)FilterExec: c4_ts_micro@0 >= 1704067200000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c5_ts_nano from t2 where extract (year from c5_ts_nano) <= 2024 +---- +logical_plan +01)Filter: t2.c5_ts_nano < TimestampNanosecond(1735689600000000000, None) +02)--TableScan: t2 projection=[c5_ts_nano] +physical_plan +01)FilterExec: c5_ts_nano@0 < 1735689600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) is not distinct from 2024 +---- +logical_plan +01)Filter: t2.c1_date32 IS NOT NULL AND t2.c1_date32 >= Date32("2024-01-01") AND t2.c1_date32 < Date32("2025-01-01") +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 IS NOT NULL AND c1_date32@0 >= 2024-01-01 AND c1_date32@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +explain select c1_date32 from t2 where extract (year from c1_date32) is distinct from 2024 +---- +logical_plan +01)Filter: t2.c1_date32 < Date32("2024-01-01") OR t2.c1_date32 >= Date32("2025-01-01") OR t2.c1_date32 IS NULL +02)--TableScan: t2 projection=[c1_date32] +physical_plan +01)FilterExec: c1_date32@0 < 2024-01-01 OR c1_date32@0 >= 2025-01-01 OR c1_date32@0 IS NULL +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Preimage with timestamp with America/New_York timezone + +statement ok +SET datafusion.execution.time_zone = 'America/New_York'; + +statement ok +create table t3( + c1_ts_tz timestamptz +) as VALUES + (NULL), + 
('2024-01-01T04:59:59Z'::timestamptz), -- local 2023-12-31 23:59:59 -05 + ('2024-01-01T05:00:00Z'::timestamptz), -- local 2024-01-01 00:00:00 -05 + ('2025-01-01T04:59:59Z'::timestamptz), -- local 2024-12-31 23:59:59 -05 + ('2025-01-01T05:00:00Z'::timestamptz) -- local 2025-01-01 00:00:00 -05 +; + +query P +select c1_ts_tz +from t3 +where extract(year from c1_ts_tz) = 2024 +order by c1_ts_tz +---- +2024-01-01T00:00:00-05:00 +2024-12-31T23:59:59-05:00 + +query TT +explain select c1_ts_tz from t3 where extract(year from c1_ts_tz) = 2024 +---- +logical_plan +01)Filter: t3.c1_ts_tz >= TimestampNanosecond(1704085200000000000, Some("America/New_York")) AND t3.c1_ts_tz < TimestampNanosecond(1735707600000000000, Some("America/New_York")) +02)--TableScan: t3 projection=[c1_ts_tz] +physical_plan +01)FilterExec: c1_ts_tz@0 >= 1704085200000000000 AND c1_ts_tz@0 < 1735707600000000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +RESET datafusion.execution.time_zone; + +# Test non-Int32 rhs argument + +query D +select c from t1 where extract(year from c) = cast(2024 as bigint); +---- +2024-01-01 + +query TT +explain select c from t1 where extract (year from c) = cast(2024 as bigint) +---- +logical_plan +01)Filter: t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01") +02)--TableScan: t1 projection=[c] +physical_plan +01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01 +02)--DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/datetime/timestamps.slt b/datafusion/sqllogictest/test_files/datetime/timestamps.slt index efa7a536c8ba4..8ed32940e826e 100644 --- a/datafusion/sqllogictest/test_files/datetime/timestamps.slt +++ b/datafusion/sqllogictest/test_files/datetime/timestamps.slt @@ -19,10 +19,10 @@ ## Common timestamp data # # ts_data: Int64 nanoseconds -# ts_data_nanos: Timestamp(Nanosecond, None) -# ts_data_micros: Timestamp(Microsecond, None) -# ts_data_millis: Timestamp(Millisecond, None) -# 
ts_data_secs: Timestamp(Second, None) +# ts_data_nanos: Timestamp(ns) +# ts_data_micros: Timestamp(µs) +# ts_data_millis: Timestamp(ms) +# ts_data_secs: Timestamp(s) ########## # Create timestamp tables with different precisions but the same logical values @@ -34,16 +34,16 @@ create table ts_data(ts bigint, value int) as values (1599565349190855123, 3); statement ok -create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(Nanosecond, None)') as ts, value from ts_data; +create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(ns)') as ts, value from ts_data; statement ok -create table ts_data_micros as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, None)') as ts, value from ts_data; +create table ts_data_micros as select arrow_cast(ts / 1000, 'Timestamp(µs)') as ts, value from ts_data; statement ok -create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(Millisecond, None)') as ts, value from ts_data; +create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(ms)') as ts, value from ts_data; statement ok -create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; +create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(s)') as ts, value from ts_data; statement ok create table ts_data_micros_kolkata as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, Some("Asia/Kolkata"))') as ts, value from ts_data; @@ -1579,13 +1579,13 @@ second 2020-09-08T13:42:29 # test date trunc on different timestamp scalar types and ensure they are consistent query P rowsort -SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)')) as ts +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(s)')) as ts UNION ALL -SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Nanosecond, None)')) as ts +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 
14:38:50Z', 'Timestamp(ns)')) as ts UNION ALL -SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Microsecond, None)')) as ts +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(µs)')) as ts UNION ALL -SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Millisecond, None)')) as ts +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(ms)')) as ts ---- 2023-08-03T00:00:00 2023-08-03T00:00:00 @@ -2706,7 +2706,7 @@ drop table ts_utf8_data ########## query B -select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); +select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(ns)'); ---- false @@ -3064,7 +3064,7 @@ NULL query error DataFusion error: Error during planning: Function 'make_date' expects 3 arguments but received 1 select make_date(1); -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\) but received NativeType::Interval\(MonthDayNano\), DataType: Interval\(MonthDayNano\) +query error DataFusion error: Error during planning: Function 'make_date' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\), but received Interval\(MonthDayNano\) \(DataType: Interval\(MonthDayNano\)\) select make_date(interval '1 day', '2001-05-21'::timestamp, '2001-05-21'::timestamp); ########## @@ -3337,7 +3337,7 @@ select make_time(22, '', 27); query error Cannot cast string '' to value of Int32 type select make_time(22, 1, ''); -query error Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\) but received NativeType::Float64, DataType: Float64 +query error DataFusion error: Error during planning: Function 'make_time' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int32\), Int32\)\), but received Float64 \(DataType: Float64\) select make_time(arrow_cast(22, 'Float64'), 1, ''); ########## @@ -3640,7 +3640,7 @@ 
select to_char(arrow_cast(12344567890000, 'Time64(Nanosecond)'), '%H-%M-%S %f') 03-25-44 567890000 query T -select to_char(arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)'), '%d-%m-%Y %H-%M-%S') +select to_char(arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(s)'), '%d-%m-%Y %H-%M-%S') ---- 03-08-2023 14-38-50 @@ -3732,7 +3732,7 @@ select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(Se 1673638290 query I -select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(Millisecond, None)')); +select to_unixtime(arrow_cast(to_timestamp('2023-01-14T01:01:30'), 'Timestamp(ms)')); ---- 1673658090 @@ -3952,7 +3952,7 @@ statement error select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); # invalid argument data type -statement error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Timestamp but received NativeType::String, DataType: Utf8 +statement error DataFusion error: Error during planning: Function 'to_local_time' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8\) select to_local_time('2024-04-01T00:00:20Z'); # invalid timezone @@ -4307,58 +4307,58 @@ SELECT CAST(CAST(one AS decimal(17,2)) AS timestamp(3)) AS a FROM (VALUES (1)) t 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(ns)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(ns)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Microsecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT 
arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(µs)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(µs)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(ms)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(ms)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Second, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Second, None)') AS a FROM (VALUES (1)) t(one); +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(s)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(s)') AS a FROM (VALUES (1)) t(one); ---- 1970-01-01T00:00:01 1970-01-01T00:00:01 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(ns)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(ns)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(µs)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(µs)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 query P 
-SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(ms)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(ms)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 query P -SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Second, None)') AS a UNION ALL -SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Second, None)') AS a FROM (VALUES (1.123)) t(one); +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(s)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(s)') AS a FROM (VALUES (1.123)) t(one); ---- 1970-01-01T00:00:01 1970-01-01T00:00:01 @@ -4410,7 +4410,7 @@ FROM ts_data_micros_kolkata ## Casting between timestamp with and without timezone ########## -# Test casting from Timestamp(Nanosecond, Some("UTC")) to Timestamp(Nanosecond, None) +# Test casting from Timestamp(Nanosecond, Some("UTC")) to Timestamp(ns) # Verifies that the underlying nanosecond values are preserved when removing timezone # Verify input type @@ -4421,13 +4421,13 @@ Timestamp(ns, "UTC") # Verify output type after casting query T -SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))'), 'Timestamp(Nanosecond, None)')); +SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))'), 'Timestamp(ns)')); ---- Timestamp(ns) # Verify values are preserved when casting from timestamp with timezone to timestamp without timezone query P rowsort -SELECT arrow_cast(column1, 'Timestamp(Nanosecond, None)') +SELECT arrow_cast(column1, 'Timestamp(ns)') FROM (VALUES (arrow_cast(1, 'Timestamp(Nanosecond, Some("UTC"))')), (arrow_cast(2, 'Timestamp(Nanosecond, Some("UTC"))')), @@ -4442,18 +4442,18 @@ FROM (VALUES 1970-01-01T00:00:00.000000004 
1970-01-01T00:00:00.000000005 -# Test casting from Timestamp(Nanosecond, None) to Timestamp(Nanosecond, Some("UTC")) +# Test casting from Timestamp(ns) to Timestamp(Nanosecond, Some("UTC")) # Verifies that the underlying nanosecond values are preserved when adding timezone # Verify input type query T -SELECT arrow_typeof(arrow_cast(1, 'Timestamp(Nanosecond, None)')); +SELECT arrow_typeof(arrow_cast(1, 'Timestamp(ns)')); ---- Timestamp(ns) # Verify output type after casting query T -SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(Nanosecond, None)'), 'Timestamp(Nanosecond, Some("UTC"))')); +SELECT arrow_typeof(arrow_cast(arrow_cast(1, 'Timestamp(ns)'), 'Timestamp(Nanosecond, Some("UTC"))')); ---- Timestamp(ns, "UTC") @@ -4461,11 +4461,11 @@ Timestamp(ns, "UTC") query P rowsort SELECT arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') FROM (VALUES - (arrow_cast(1, 'Timestamp(Nanosecond, None)')), - (arrow_cast(2, 'Timestamp(Nanosecond, None)')), - (arrow_cast(3, 'Timestamp(Nanosecond, None)')), - (arrow_cast(4, 'Timestamp(Nanosecond, None)')), - (arrow_cast(5, 'Timestamp(Nanosecond, None)')) + (arrow_cast(1, 'Timestamp(ns)')), + (arrow_cast(2, 'Timestamp(ns)')), + (arrow_cast(3, 'Timestamp(ns)')), + (arrow_cast(4, 'Timestamp(ns)')), + (arrow_cast(5, 'Timestamp(ns)')) ) t; ---- 1970-01-01T00:00:00.000000001Z @@ -5328,3 +5328,33 @@ drop table ts_data_secs statement ok drop table ts_data_micros_kolkata + +########## +## Test to_timestamp with scalar float inputs +########## + +statement ok +create table test_to_timestamp_scalar(id int, name varchar) as values + (1, 'foo'), + (2, 'bar'); + +query P +SELECT to_timestamp(123.5, name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:02:03.500 +1970-01-01T00:02:03.500 + +query P +SELECT to_timestamp(456.789::float, name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:07:36.789001464 +1970-01-01T00:07:36.789001464 + +query P +SELECT to_timestamp(arrow_cast(100.5, 'Float16'), 
name) FROM test_to_timestamp_scalar ORDER BY id; +---- +1970-01-01T00:01:40.500 +1970-01-01T00:01:40.500 + +statement ok +drop table test_to_timestamp_scalar diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt index e86343b6bf5fb..b01eb6f5e9ec7 100644 --- a/datafusion/sqllogictest/test_files/delete.slt +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -113,3 +113,30 @@ logical_plan 05)--------TableScan: t2 06)----TableScan: t1 physical_plan_error This feature is not implemented: Physical plan does not support logical expression InSubquery(InSubquery { expr: Column(Column { relation: Some(Bare { table: "t1" }), name: "a" }), subquery: , negated: false }) + + +# Delete with limit + +query TT +explain delete from t1 limit 10 +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Limit: skip=0, fetch=10 +03)----TableScan: t1 +physical_plan +01)CooperativeExec +02)--DmlResultExec: rows_affected=0 + + +query TT +explain delete from t1 where a = 1 and b = '2' limit 10 +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Limit: skip=0, fetch=10 +03)----Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Utf8("2") AS Utf8View) +04)------TableScan: t1 +physical_plan +01)CooperativeExec +02)--DmlResultExec: rows_affected=0 diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index b6098758a9e67..511061cf82f06 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -36,7 +36,7 @@ SELECT arrow_cast(column3, 'Utf8') as f2, arrow_cast(column4, 'Utf8') as f3, arrow_cast(column5, 'Float64') as f4, - arrow_cast(column6, 'Timestamp(Nanosecond, None)') as time + arrow_cast(column6, 'Timestamp(ns)') as time FROM ( VALUES -- equivalent to the following line protocol data @@ -111,7 +111,7 @@ SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as type, arrow_cast(column2, 'Dictionary(Int32, 
Utf8)') as tag_id, arrow_cast(column3, 'Float64') as f5, - arrow_cast(column4, 'Timestamp(Nanosecond, None)') as time + arrow_cast(column4, 'Timestamp(ns)') as time FROM ( VALUES -- equivalent to the following line protocol data diff --git a/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt b/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt index 3e403171e0718..275b0c9dd490f 100644 --- a/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt +++ b/datafusion/sqllogictest/test_files/dynamic_filter_pushdown_config.slt @@ -92,6 +92,30 @@ physical_plan 01)SortExec: TopK(fetch=3), expr=[value@1 DESC], preserve_partitioning=[false] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value, name], file_type=parquet, predicate=DynamicFilter [ empty ] +statement ok +set datafusion.explain.analyze_level = summary; + +query TT +EXPLAIN ANALYZE SELECT id, value AS v, value + id as name FROM test_parquet where value > 3 ORDER BY v DESC LIMIT 3; +---- +Plan with Metrics +01)SortPreservingMergeExec: [v@1 DESC], fetch=3, metrics=[output_rows=3, ] +02)--SortExec: TopK(fetch=3), expr=[v@1 DESC], preserve_partitioning=[true], filter=[v@1 IS NULL OR v@1 > 800], metrics=[output_rows=3, ] +03)----ProjectionExec: expr=[id@0 as id, value@1 as v, value@1 + id@0 as name], metrics=[output_rows=10, ] +04)------FilterExec: value@1 > 3, metrics=[output_rows=10, , selectivity=100% (10/10)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, metrics=[output_rows=10, ] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value], file_type=parquet, predicate=value@1 > 3 AND DynamicFilter [ value@1 IS NULL OR value@1 > 800 ], pruning_predicate=value_null_count@1 != 
row_count@2 AND value_max@0 > 3 AND (value_null_count@1 > 0 OR value_null_count@1 != row_count@2 AND value_max@0 > 800), required_guarantees=[], metrics=[output_rows=10, , files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=1 total → 1 matched -> 1 fully matched, row_groups_pruned_bloom_filter=1 total → 1 matched, page_index_pages_pruned=1 total → 1 matched, limit_pruned_row_groups=0 total → 0 matched, bytes_scanned=210, metadata_load_time=, scan_efficiency_ratio=18% (210/1.16 K)] + +statement ok +set datafusion.explain.analyze_level = dev; + +query III +SELECT id, value AS v, value + id as name FROM test_parquet where value > 3 ORDER BY v DESC LIMIT 3; +---- +10 1000 1010 +9 900 909 +8 800 808 + # Disable TopK dynamic filter pushdown statement ok SET datafusion.optimizer.enable_topk_dynamic_filter_pushdown = false; @@ -106,6 +130,13 @@ physical_plan 01)SortExec: TopK(fetch=3), expr=[value@1 DESC], preserve_partitioning=[false] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/test_data.parquet]]}, projection=[id, value, name], file_type=parquet +query IIT +SELECT id, value AS v, name FROM (SELECT * FROM test_parquet UNION ALL SELECT * FROM test_parquet) ORDER BY v DESC LIMIT 3; +---- +10 1000 j +10 1000 j +9 900 i + # Re-enable for next tests statement ok SET datafusion.optimizer.enable_topk_dynamic_filter_pushdown = true; @@ -156,6 +187,197 @@ physical_plan statement ok SET datafusion.optimizer.enable_join_dynamic_filter_pushdown = true; +# Test 2b: Dynamic filter pushdown for non-inner join types +# LEFT JOIN: optimizer swaps to physical Right join (build=right_parquet, probe=left_parquet). +# Dynamic filter is NOT pushed because Right join needs all probe rows in output. 
+query TT +EXPLAIN SELECT l.*, r.info +FROM left_parquet l +LEFT JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, l.data, r.info +02)--Left Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@1 as id, data@2 as data, info@0 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(id@0, id@0)], projection=[info@1, id@2, data@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT JOIN correctness: all left rows appear, unmatched right rows produce NULLs +query ITT +SELECT l.id, l.data, r.info +FROM left_parquet l +LEFT JOIN right_parquet r ON l.id = r.id +ORDER BY l.id; +---- +1 left1 right1 +2 left2 NULL +3 left3 right3 +4 left4 NULL +5 left5 right5 + +# RIGHT JOIN: optimizer swaps to physical Left join (build=right_parquet, probe=left_parquet). +# No self-generated dynamic filter (only Inner joins get that), but parent filters +# on the preserved (build) side can still push down. 
+query TT +EXPLAIN SELECT l.*, r.info +FROM left_parquet l +RIGHT JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, l.data, r.info +02)--Right Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@1 as id, data@2 as data, info@0 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@0, id@0)], projection=[info@1, id@2, data@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# RIGHT JOIN correctness: all right rows appear, unmatched left rows produce NULLs +query ITT +SELECT l.id, l.data, r.info +FROM left_parquet l +RIGHT JOIN right_parquet r ON l.id = r.id +ORDER BY r.id; +---- +1 left1 right1 +3 left3 right3 +5 left5 right5 + +# FULL JOIN: dynamic filter should NOT be pushed (both sides must preserve all rows) +query TT +EXPLAIN SELECT l.id, r.id as rid, l.data, r.info +FROM left_parquet l +FULL JOIN right_parquet r ON l.id = r.id; +---- +logical_plan +01)Projection: l.id, r.id AS rid, l.data, r.info +02)--Full Join: l.id = r.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id, info] +physical_plan +01)ProjectionExec: expr=[id@2 as id, id@0 as rid, data@3 as data, info@1 as info] +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT SEMI JOIN: optimizer swaps to RightSemi (build=right_parquet, probe=left_parquet). +# No self-generated dynamic filter (only Inner joins), but parent filters on +# the preserved (probe) side can push down. +query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r); +---- +logical_plan +01)LeftSemi Join: l.id = __correlated_sq_1.id +02)--SubqueryAlias: l +03)----TableScan: left_parquet projection=[id, data] +04)--SubqueryAlias: __correlated_sq_1 +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet + +# LEFT ANTI JOIN: no self-generated dynamic filter, but parent filters can push +# to the preserved (left/build) side. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r); +---- +logical_plan +01)LeftAnti Join: l.id = __correlated_sq_1.id +02)--SubqueryAlias: l +03)----TableScan: left_parquet projection=[id, data] +04)--SubqueryAlias: __correlated_sq_1 +05)----SubqueryAlias: r +06)------TableScan: right_parquet projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet + +# Test 2c: Parent dynamic filter (from TopK) pushed through semi/anti joins +# Sort on the join key (id) so the TopK dynamic filter pushes to BOTH sides. + +# SEMI JOIN with TopK parent: TopK generates a dynamic filter on `id` (join key) +# that pushes through the RightSemi join to both the build and probe sides. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +logical_plan +01)Sort: l.id ASC NULLS LAST, fetch=2 +02)--LeftSemi Join: l.id = __correlated_sq_1.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: __correlated_sq_1 +06)------SubqueryAlias: r +07)--------TableScan: right_parquet projection=[id] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Correctness check +query IT +SELECT l.* +FROM left_parquet l +WHERE l.id IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +1 left1 +3 left3 + +# ANTI JOIN with TopK parent: TopK generates a dynamic filter on `id` (join key) +# that pushes through the LeftAnti join to both the preserved and non-preserved sides. 
+query TT +EXPLAIN SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +logical_plan +01)Sort: l.id ASC NULLS LAST, fetch=2 +02)--LeftAnti Join: l.id = __correlated_sq_1.id +03)----SubqueryAlias: l +04)------TableScan: left_parquet projection=[id, data] +05)----SubqueryAlias: __correlated_sq_1 +06)------SubqueryAlias: r +07)--------TableScan: right_parquet projection=[id] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Correctness check +query IT +SELECT l.* +FROM left_parquet l +WHERE l.id NOT IN (SELECT r.id FROM right_parquet r) +ORDER BY l.id LIMIT 2; +---- +2 left2 +4 left4 + # Test 3: Test independent control # Disable TopK, keep Join enabled @@ -257,6 +479,25 @@ physical_plan 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/agg_data.parquet]]}, projection=[score], file_type=parquet, predicate=category@0 = alpha AND DynamicFilter [ empty ], pruning_predicate=category_null_count@2 != row_count@3 AND category_min@0 <= alpha AND alpha <= category_max@1, required_guarantees=[category in (alpha)] +# Test 4b: COUNT + MAX — DynamicFilter should NOT appear here in mixed aggregates + +query TT +EXPLAIN SELECT COUNT(*), MAX(score) FROM agg_parquet 
WHERE category = 'alpha'; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*), max(agg_parquet.score) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)), max(agg_parquet.score)]] +03)----Projection: agg_parquet.score +04)------Filter: agg_parquet.category = Utf8View("alpha") +05)--------TableScan: agg_parquet projection=[category, score], partial_filters=[agg_parquet.category = Utf8View("alpha")] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*), max(agg_parquet.score)@1 as max(agg_parquet.score)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1)), max(agg_parquet.score)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1)), max(agg_parquet.score)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/agg_data.parquet]]}, projection=[score], file_type=parquet, predicate=category@0 = alpha, pruning_predicate=category_null_count@2 != row_count@3 AND category_min@0 <= alpha AND alpha <= category_max@1, required_guarantees=[category in (alpha)] + # Disable aggregate dynamic filters only statement ok SET datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown = false; @@ -388,6 +629,97 @@ physical_plan 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_right.parquet]]}, projection=[id, info], file_type=parquet 04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/dynamic_filter_pushdown_config/join_left.parquet]]}, projection=[id, data], file_type=parquet, predicate=DynamicFilter [ empty ] +# Test 6: Regression test for issue #20213 - dynamic filter applied to wrong table +# when subquery join has same column names on both sides. 
+# +# The bug: when an outer join pushes a DynamicFilter for column "k" through an +# inner join where both sides have a column named "k", the name-based routing +# incorrectly pushed the filter to BOTH sides instead of only the correct one. +# This caused wrong results (0 rows instead of expected). + +# Create tables with same column names (k, v) on both sides +statement ok +CREATE TABLE issue_20213_t1(k INT, v INT) AS +SELECT i as k, i as v FROM generate_series(1, 1000) t(i); + +statement ok +CREATE TABLE issue_20213_t2(k INT, v INT) AS +SELECT i + 100 as k, i as v FROM generate_series(1, 100) t(i); + +# Use small row groups to make statistics-based pruning more likely to manifest the bug +statement ok +SET datafusion.execution.parquet.max_row_group_size = 10; + +query I +COPY issue_20213_t1 TO 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t1.parquet' STORED AS PARQUET; +---- +1000 + +query I +COPY issue_20213_t2 TO 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t2.parquet' STORED AS PARQUET; +---- +100 + +# Reset row group size +statement ok +SET datafusion.execution.parquet.max_row_group_size = 1000000; + +statement ok +CREATE EXTERNAL TABLE t1_20213(k INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t1.parquet'; + +statement ok +CREATE EXTERNAL TABLE t2_20213(k INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/dynamic_filter_pushdown_config/issue_20213_t2.parquet'; + +# The query from issue #20213: subquery joins t1 and t2 on v, then outer +# join uses t2's k column. The dynamic filter on k from the outer join +# must only apply to t2 (k range 101-200), NOT to t1 (k range 1-1000). 
+query I +SELECT count(*) FROM ( + SELECT t2_20213.k as k, t1_20213.k as k2 + FROM t1_20213 + JOIN t2_20213 ON t1_20213.v = t2_20213.v +) a +JOIN t2_20213 b ON a.k = b.k +WHERE b.v < 10; +---- +9 + +# Also verify with SELECT * to catch row-level correctness +query IIII rowsort +SELECT * FROM ( + SELECT t2_20213.k as k, t1_20213.k as k2 + FROM t1_20213 + JOIN t2_20213 ON t1_20213.v = t2_20213.v +) a +JOIN t2_20213 b ON a.k = b.k +WHERE b.v < 10; +---- +101 1 101 1 +102 2 102 2 +103 3 103 3 +104 4 104 4 +105 5 105 5 +106 6 106 6 +107 7 107 7 +108 8 108 8 +109 9 109 9 + +statement ok +DROP TABLE issue_20213_t1; + +statement ok +DROP TABLE issue_20213_t2; + +statement ok +DROP TABLE t1_20213; + +statement ok +DROP TABLE t2_20213; + # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index ef91eade01e5b..b04d5061825b4 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -20,21 +20,41 @@ SELECT encode(arrow_cast('tom', 'Utf8View'),'base64'); ---- dG9t +query T +SELECT encode(arrow_cast('tommy', 'Utf8View'),'base64pad'); +---- +dG9tbXk= + query T SELECT arrow_cast(decode(arrow_cast('dG9t', 'Utf8View'),'base64'), 'Utf8'); ---- tom +query T +SELECT arrow_cast(decode(arrow_cast('dG9tbXk=', 'Utf8View'),'base64pad'), 'Utf8'); +---- +tommy + query T SELECT encode(arrow_cast('tom', 'BinaryView'),'base64'); ---- dG9t +query T +SELECT encode(arrow_cast('tommy', 'BinaryView'),'base64pad'); +---- +dG9tbXk= + query T SELECT arrow_cast(decode(arrow_cast('dG9t', 'BinaryView'),'base64'), 'Utf8'); ---- tom +query T +SELECT arrow_cast(decode(arrow_cast('dG9tbXk=', 'BinaryView'),'base64pad'), 'Utf8'); +---- +tommy + # test for hex digest query T select encode(digest('hello', 'sha256'), 'hex'); @@ -55,16 +75,16 @@ CREATE TABLE test( ; # errors -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Binary 
but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'encode' requires TypeSignatureClass::Binary, but received Int64 \(DataType: Int64\) select encode(12, 'hex'); -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Binary but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'decode' requires TypeSignatureClass::Binary, but received Int64 \(DataType: Int64\) select decode(12, 'hex'); -query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, base64pad, hex select encode('', 'non_encoding'); -query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, base64pad, hex select decode('', 'non_encoding'); query error DataFusion error: Execution error: Encoding must be a non-null string @@ -73,7 +93,7 @@ select decode('', null) from test; query error DataFusion error: This feature is not implemented: Encoding must be a scalar; array specified encoding is not yet supported select decode('', hex_field) from test; -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Integer but received NativeType::String, DataType: Utf8View +query error DataFusion error: Error during planning: Function 'to_hex' requires TypeSignatureClass::Integer, but received String \(DataType: Utf8View\) select to_hex(hex_field) from test; query error DataFusion error: Execution error: Failed to decode value using base64 @@ 
-124,11 +144,21 @@ select encode(bin_field, 'base64') FROM test WHERE num = 3; ---- j1DT9g6uNw3b+FyGIZxVEIo1AWU +query T +select encode(bin_field, 'base64pad') FROM test WHERE num = 3; +---- +j1DT9g6uNw3b+FyGIZxVEIo1AWU= + query B select decode(encode(bin_field, 'base64'), 'base64') = X'8f50d3f60eae370ddbf85c86219c55108a350165' FROM test WHERE num = 3; ---- true +query B +select decode(encode(bin_field, 'base64pad'), 'base64pad') = X'8f50d3f60eae370ddbf85c86219c55108a350165' FROM test WHERE num = 3; +---- +true + statement ok drop table test @@ -144,18 +174,20 @@ FROM VALUES ('Raphael', 'R'), (NULL, 'R'); -query TTTT +query TTTTTT SELECT encode(column1_utf8view, 'base64') AS column1_base64, + encode(column1_utf8view, 'base64pad') AS column1_base64pad, encode(column1_utf8view, 'hex') AS column1_hex, encode(column2_utf8view, 'base64') AS column2_base64, + encode(column2_utf8view, 'base64pad') AS column2_base64pad, encode(column2_utf8view, 'hex') AS column2_hex FROM test_utf8view; ---- -QW5kcmV3 416e64726577 WA 58 -WGlhbmdwZW5n 5869616e6770656e67 WGlhbmdwZW5n 5869616e6770656e67 -UmFwaGFlbA 5261706861656c Ug 52 -NULL NULL Ug 52 +QW5kcmV3 QW5kcmV3 416e64726577 WA WA== 58 +WGlhbmdwZW5n WGlhbmdwZW5n 5869616e6770656e67 WGlhbmdwZW5n WGlhbmdwZW5n 5869616e6770656e67 +UmFwaGFlbA UmFwaGFlbA== 5261706861656c Ug Ug== 52 +NULL NULL NULL Ug Ug== 52 query TTTTTT SELECT @@ -172,6 +204,22 @@ WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA UmFwaGFlbA NULL NULL NULL NULL NULL NULL + +query TTTTTT +SELECT + encode(arrow_cast(column1_utf8view, 'Utf8'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'LargeUtf8'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'Utf8View'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'Binary'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'LargeBinary'), 'base64pad'), + encode(arrow_cast(column1_utf8view, 'BinaryView'), 'base64pad') +FROM 
test_utf8view; +---- +QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 QW5kcmV3 +WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n WGlhbmdwZW5n +UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== UmFwaGFlbA== +NULL NULL NULL NULL NULL NULL + statement ok drop table test_utf8view @@ -180,26 +228,31 @@ statement ok CREATE TABLE test_fsb AS SELECT arrow_cast(X'0123456789ABCDEF', 'FixedSizeBinary(8)') as fsb_col; -query ?? +query ??? SELECT decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'base64'), 'base64'), + decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'base64pad'), 'base64pad'), decode(encode(arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)'), 'hex'), 'hex'); ---- -0123456789abcdef 0123456789abcdef +0123456789abcdef 0123456789abcdef 0123456789abcdef -query ?? +query ??? SELECT decode(encode(column1, 'base64'), 'base64'), + decode(encode(column1, 'base64pad'), 'base64pad'), decode(encode(column1, 'hex'), 'hex') FROM values (arrow_cast(X'0123456789abcdef', 'FixedSizeBinary(8)')), (arrow_cast(X'ffffffffffffffff', 'FixedSizeBinary(8)')); ---- -0123456789abcdef 0123456789abcdef -ffffffffffffffff ffffffffffffffff +0123456789abcdef 0123456789abcdef 0123456789abcdef +ffffffffffffffff ffffffffffffffff ffffffffffffffff query error DataFusion error: Execution error: Failed to decode value using base64 select decode('invalid', 'base64'); +query error DataFusion error: Execution error: Failed to decode value using base64pad +select decode('invalid', 'base64pad'); + query error DataFusion error: Execution error: Failed to decode value using hex select decode('invalid', 'hex'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3cedb648951cf..c5907d497500e 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan logical_plan after 
resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -196,7 +197,10 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -217,6 +221,8 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true @@ -297,8 +303,8 @@ initial_physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: 
ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] initial_physical_plan_with_schema -01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, 
int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] physical_plan after OutputRequirements 01)OutputRequirementExec: order_by=[], dist_by=Unspecified, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] @@ -324,7 +330,7 @@ physical_plan after EnsureCooperative SAME TEXT AS ABOVE physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, 
double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] -physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] statement ok @@ -341,8 +347,8 @@ initial_physical_plan_with_stats 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: 
ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] initial_physical_plan_with_schema -01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] +02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] physical_plan after OutputRequirements 01)OutputRequirementExec: order_by=[], dist_by=Unspecified 02)--GlobalLimitExec: skip=0, fetch=10 @@ -369,7 +375,7 @@ physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]: ScanBytes=Exact(32)),(Col[1]: ScanBytes=Inexact(24)),(Col[2]: ScanBytes=Exact(32)),(Col[3]: ScanBytes=Exact(32)),(Col[4]: ScanBytes=Exact(32)),(Col[5]: ScanBytes=Exact(64)),(Col[6]: ScanBytes=Exact(32)),(Col[7]: ScanBytes=Exact(64)),(Col[8]: ScanBytes=Inexact(88)),(Col[9]: ScanBytes=Inexact(49)),(Col[10]: ScanBytes=Exact(64))]] -physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], 
limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] +physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(ns);N] statement ok @@ -535,6 +541,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -555,7 +562,10 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after 
replace_distinct_aggregate SAME TEXT AS ABOVE @@ -576,6 +586,8 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE +logical_plan after push_down_leaf_projections SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index cec9b63675a66..c737efca4a6d0 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -60,7 +60,7 @@ SELECT isnan(NULL), iszero(NULL) ---- -NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 NULL NULL NULL # test_array_cast_invalid_timezone_will_panic statement error Parser error: Invalid timezone "Foo": failed to parse timezone @@ -432,6 +432,16 @@ SELECT chr(CAST(0 AS int)) statement error DataFusion error: Execution error: invalid Unicode scalar value: 9223372036854775807 SELECT chr(CAST(9223372036854775807 AS bigint)) +statement error DataFusion error: Execution error: invalid Unicode scalar value: 1114112 +SELECT chr(CAST(1114112 AS bigint)) + +statement error DataFusion error: Execution error: invalid Unicode scalar value: -1 +SELECT chr(CAST(-1 AS bigint)) + +# surrogate code 
point (invalid scalar value) +statement error DataFusion error: Execution error: invalid Unicode scalar value: 55297 +SELECT chr(CAST(55297 AS bigint)) + query T SELECT concat('a','b','c') ---- @@ -494,6 +504,25 @@ abc statement ok drop table foo +# concat_ws with a Utf8View column as separator +statement ok +create table test_concat_ws_sep (sep varchar, val1 varchar, val2 varchar) as values (',', 'foo', 'bar'), ('|', 'a', 'b'); + +query T +SELECT concat_ws(arrow_cast(sep, 'Utf8View'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +query T +SELECT concat_ws(arrow_cast(sep, 'LargeUtf8'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +statement ok +drop table test_concat_ws_sep + query T SELECT initcap('') ---- @@ -589,7 +618,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64 +query error DataFusion error: Error during planning: Function 'repeat' requires TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\), but received Float64 \(DataType: Float64\) select repeat('-1.2', 3.2); query T @@ -715,6 +744,27 @@ SELECT to_hex(CAST(NULL AS int)) ---- NULL +query T +SELECT to_hex(0) +---- +0 + +# negative values (two's complement encoding) +query T +SELECT to_hex(-1) +---- +ffffffffffffffff + +query T +SELECT to_hex(CAST(-1 AS INT)) +---- +ffffffffffffffff + +query T +SELECT to_hex(CAST(255 AS TINYINT UNSIGNED)) +---- +ff + query T SELECT trim(' tom ') ---- diff --git a/datafusion/sqllogictest/test_files/extract_tz.slt b/datafusion/sqllogictest/test_files/extract_tz.slt new file mode 100644 index 0000000000000..e0dc37e6965d8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/extract_tz.slt @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for timezone-aware extract SQL statement support. +# Test with different timezone +statement ok +SET datafusion.execution.time_zone = '-03:00'; + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-11-18 10:00:00'); +---- +7 + +query II +SELECT EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 10:45:30'), + EXTRACT(SECOND FROM TIMESTAMP '2023-10-30 10:45:30'); +---- +45 30 + +query III +SELECT EXTRACT(YEAR FROM DATE '2023-10-30'), + EXTRACT(MONTH FROM DATE '2023-10-30'), + EXTRACT(DAY FROM DATE '2023-10-30'); +---- +2023 10 30 + +query I +SELECT EXTRACT(HOUR FROM CAST(NULL AS TIMESTAMP)); +---- +NULL + +statement ok +SET datafusion.execution.time_zone = '+04:00'; + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2023-10-30 02:00:00'); +---- +6 + +query III +SELECT EXTRACT(HOUR FROM TIMESTAMP '2023-10-30 18:20:59'), + EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 18:20:59'), + EXTRACT(SECOND FROM TIMESTAMP '2023-10-30 18:20:59'); +---- +22 20 59 + +query II +SELECT EXTRACT(DOW FROM DATE '2025-11-01'), + EXTRACT(DOY FROM DATE '2026-12-31'); +---- +6 365 + +statement ok +SET datafusion.execution.time_zone = '+00:00'; + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-10-30 10:45:30+02:00'); +---- +8 + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-10-30 10:45:30-05:00'); 
+---- +15 + +query II +SELECT EXTRACT(YEAR FROM TIMESTAMP '2026-11-30 10:45:30Z'), + EXTRACT(MONTH FROM TIMESTAMP '2023-10-30 10:45:30Z'); +---- +2026 10 + +query III +SELECT EXTRACT(HOUR FROM TIMESTAMP '2023-10-30 18:20:59+04:00'), + EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 18:20:59+04:00'), + EXTRACT(SECOND FROM TIMESTAMP '2023-10-30 18:20:59+04:00'); +---- +14 20 59 + +query II +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-10-30 10:25:30+02:30'), + EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 18:20:59-04:30'); +---- +7 50 + +query III +SELECT EXTRACT(HOUR FROM TIMESTAMP '2023-10-30 18:20:59-08:00'), + EXTRACT(DAY FROM TIMESTAMP '2023-10-30 18:20:59-07:00'), + EXTRACT(DAY FROM TIMESTAMP '2023-10-30 07:20:59+12:00'); +---- +2 31 29 + +query IIIIII +SELECT EXTRACT(YEAR FROM TIMESTAMP '2023-12-31 18:20:59-08:45'), + EXTRACT(MONTH FROM TIMESTAMP '2023-12-31 18:20:59-08:45'), + EXTRACT(DAY FROM TIMESTAMP '2023-12-31 18:20:59-08:45'), + EXTRACT(HOUR FROM TIMESTAMP '2023-12-31 18:20:59-08:45'), + EXTRACT(MINUTE FROM TIMESTAMP '2023-12-31 18:20:59-08:45'), + EXTRACT(SECOND FROM TIMESTAMP '2023-12-31 18:20:59-08:45'); +---- +2024 1 1 3 5 59 + +query IIIIII +SELECT EXTRACT(YEAR FROM TIMESTAMP '2024-01-01 03:05:59+08:45'), + EXTRACT(MONTH FROM TIMESTAMP '2024-01-01 03:05:59+08:45'), + EXTRACT(DAY FROM TIMESTAMP '2024-01-01 03:05:59+08:45'), + EXTRACT(HOUR FROM TIMESTAMP '2024-01-01 03:05:59+08:45'), + EXTRACT(MINUTE FROM TIMESTAMP '2024-01-01 03:05:59+08:45'), + EXTRACT(SECOND FROM TIMESTAMP '2024-01-01 03:05:59+08:45'); +---- +2023 12 31 18 20 59 + +statement ok +SET datafusion.execution.time_zone = 'Asia/Kolkata'; + +query IIII +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-11-22 15:30:45'), +EXTRACT(MINUTE FROM TIMESTAMP '2025-11-22 15:30:45'), +EXTRACT(DOW FROM TIMESTAMP '2025-11-22 00:00:00'), +EXTRACT(SECOND FROM TIMESTAMP '2024-01-01 03:05:59'); +---- +21 0 6 59 + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-01-15 10:00:00'); +---- +15 + +statement ok +SET 
datafusion.execution.time_zone = 'America/New_York'; + +query IIII +SELECT +EXTRACT(HOUR FROM TIMESTAMP '2025-11-22 15:30:45'), +EXTRACT(MINUTE FROM TIMESTAMP '2025-11-22 15:30:45'), +EXTRACT(DOW FROM TIMESTAMP '2025-11-22 00:00:00'), +EXTRACT(SECOND FROM TIMESTAMP '2024-01-01 03:05:59'); +---- +10 30 5 59 + +query I +SELECT EXTRACT(HOUR FROM TIMESTAMP '2025-01-15 10:00:00'); +---- +5 + +statement ok +SET datafusion.execution.time_zone = '-03:30'; + +query II +SELECT EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 10:45:30'), +EXTRACT(SECOND FROM TIMESTAMP '2023-10-30 10:45:30'); +---- +15 30 + +statement ok +SET datafusion.execution.time_zone = 'America/St_Johns'; + +query II +SELECT EXTRACT(MINUTE FROM TIMESTAMP '2023-10-30 10:45:30'), +EXTRACT(SECOND FROM TIMESTAMP '2023-10-30 10:45:30'); +---- +15 30 diff --git a/datafusion/sqllogictest/test_files/floor_preimage.slt b/datafusion/sqllogictest/test_files/floor_preimage.slt new file mode 100644 index 0000000000000..93302b3d7a2f6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/floor_preimage.slt @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Floor Preimage Tests +## +## Tests for floor function preimage optimization: +## floor(col) = N transforms to col >= N AND col < N + 1 +## +## Uses representative types only (Float64, Int32, Decimal128). +## Unit tests cover all type variants. +########## + +# Setup: Single table with representative types +statement ok +CREATE TABLE test_data ( + id INT, + float_val DOUBLE, + int_val INT, + decimal_val DECIMAL(10,2) +) AS VALUES + (1, 5.3, 100, 100.00), + (2, 5.7, 101, 100.50), + (3, 6.0, 102, 101.00), + (4, 6.5, -5, 101.99), + (5, 7.0, 0, 102.00), + (6, NULL, NULL, NULL); + +########## +## Data Correctness Tests +########## + +# Float64: floor(x) = 5 matches values in [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +1 +2 + +# Int32: floor(x) = 100 matches values in [100, 101) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 100; +---- +1 + +# Decimal128: floor(x) = 100 matches values in [100.00, 101.00) +query I rowsort +SELECT id FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +1 +2 + +# Negative value: floor(x) = -5 matches values in [-5, -4) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = -5; +---- +4 + +# Zero value: floor(x) = 0 matches values in [0, 1) +query I rowsort +SELECT id FROM test_data WHERE floor(int_val) = 0; +---- +5 + +# Column on RHS (same result as LHS) +query I rowsort +SELECT id FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +1 +2 + +# IS NOT DISTINCT FROM (excludes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +1 +2 + +# IS DISTINCT FROM (includes NULLs) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +3 +4 +5 +6 + +# Non-integer literal (empty result - floor returns integers) +query I rowsort +SELECT id FROM test_data 
WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- + +# IN list: floor(x) IN (5, 7) matches [5.0, 6.0) and [7.0, 8.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +1 +2 +5 + +# NOT IN list: floor(x) NOT IN (5, 7) excludes matching ranges and NULLs +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) NOT IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +3 +4 + +########## +## EXPLAIN Tests - Plan Optimization +########## + +statement ok +set datafusion.explain.logical_plan_only = true; + +# 1. Basic: Float64 - floor(col) = N transforms to col >= N AND col < N+1 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 2. Basic: Int32 - transformed (coerced to Float64) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(int_val) = 100; +---- +logical_plan +01)Projection: test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +02)--Filter: __common_expr_3 >= Float64(100) AND __common_expr_3 < Float64(101) +03)----Projection: CAST(test_data.int_val AS Float64) AS __common_expr_3, test_data.id, test_data.float_val, test_data.int_val, test_data.decimal_val +04)------TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 3. Basic: Decimal128 - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(decimal_val) = arrow_cast(100, 'Decimal128(10,2)'); +---- +logical_plan +01)Filter: test_data.decimal_val >= Decimal128(Some(10000),10,2) AND test_data.decimal_val < Decimal128(Some(10100),10,2) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 4. 
Column on RHS - same transformation +query TT +EXPLAIN SELECT * FROM test_data WHERE arrow_cast(5, 'Float64') = floor(float_val); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 5. IS NOT DISTINCT FROM - adds IS NOT NULL +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS NOT DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val IS NOT NULL AND test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 6. IS DISTINCT FROM - includes NULL check +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IS DISTINCT FROM arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) OR test_data.float_val IS NULL +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 7. Non-optimizable: non-integer literal (original predicate preserved) +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64'); +---- +logical_plan +01)Filter: floor(test_data.float_val) = Float64(5.5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 8. Non-optimizable: extreme float literal (2^53) where n+1 loses precision, so preimage returns None +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) = 9007199254740992; +---- +logical_plan +01)Filter: floor(test_data.float_val) = Float64(9007199254740992) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# 9. 
IN list: each list item is rewritten with preimage and OR-ed together +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64')); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) OR test_data.float_val >= Float64(7) AND test_data.float_val < Float64(8) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Data correctness: floor(col) = 2^53 returns no rows (no value in test_data has floor exactly 2^53) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) = 9007199254740992; +---- + +########## +## Other Comparison Operators +## +## The preimage framework automatically handles all comparison operators: +## floor(x) <> N -> x < N OR x >= N+1 +## floor(x) > N -> x >= N+1 +## floor(x) < N -> x < N +## floor(x) >= N -> x >= N +## floor(x) <= N -> x < N+1 +########## + +# Data correctness tests for other operators + +# Not equals: floor(x) <> 5 matches values outside [5.0, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Greater than: floor(x) > 5 matches values in [6.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +3 +4 +5 + +# Less than: floor(x) < 6 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +1 +2 + +# Greater than or equal: floor(x) >= 5 matches values in [5.0, inf) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +1 +2 +3 +4 +5 + +# Less than or equal: floor(x) <= 5 matches values in (-inf, 6.0) +query I rowsort +SELECT id FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +1 +2 + +# EXPLAIN tests showing optimized transformations + +# Not equals: floor(x) <> 5 -> x < 5 OR x >= 6 +query TT +EXPLAIN SELECT * FROM test_data 
WHERE floor(float_val) <> arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(5) OR test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than: floor(x) > 5 -> x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) > arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than: floor(x) < 6 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) < arrow_cast(6, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Greater than or equal: floor(x) >= 5 -> x >= 5 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) >= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val >= Float64(5) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +# Less than or equal: floor(x) <= 5 -> x < 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE floor(float_val) <= arrow_cast(5, 'Float64'); +---- +logical_plan +01)Filter: test_data.float_val < Float64(6) +02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val] + +########## +## Cleanup +########## + +statement ok +DROP TABLE test_data; diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 6c87d618c7278..35a32897d03f5 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -239,6 +239,11 @@ SELECT translate('12345', '143', NULL) ---- NULL +query T +SELECT translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax') +---- +a2x5 + statement ok CREATE TABLE test( c1 VARCHAR diff --git a/datafusion/sqllogictest/test_files/group_by.slt 
b/datafusion/sqllogictest/test_files/group_by.slt index cd1ed2bc0caca..294841552a66d 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4329,9 +4329,9 @@ physical_plan 01)SortPreservingMergeExec: [months@0 DESC], fetch=5 02)--SortExec: TopK(fetch=5), expr=[months@0 DESC], preserve_partitioning=[true] 03)----ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] -04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 05)--------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 -06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], file_type=csv, has_header=false @@ -5478,7 +5478,7 @@ create table source as values ; statement ok -create view t as select column1 as a, arrow_cast(column2, 'Timestamp(Nanosecond, None)') as b from source; +create view t as select column1 as a, arrow_cast(column2, 'Timestamp(ns)') as b from source; query IPI select a, b, count(*) from t group by a, b order by a, b; diff --git a/datafusion/sqllogictest/test_files/grouping_set_repartition.slt b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt new file mode 100644 
index 0000000000000..16ab90651c8b3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for ROLLUP/CUBE/GROUPING SETS with multiple partitions +# +# This tests the fix for https://github.com/apache/datafusion/issues/19849 +# where ROLLUP queries produced incorrect results with multiple partitions +# because subset partitioning satisfaction was incorrectly applied. +# +# The bug manifests when: +# 1. UNION ALL of subqueries each with hash-partitioned aggregates +# 2. Outer ROLLUP groups by more columns than inner hash partitioning +# 3. InterleaveExec preserves the inner hash partitioning +# 4. Optimizer incorrectly uses subset satisfaction, skipping necessary repartition +# +# The fix ensures that when hash partitioning includes __grouping_id, +# subset satisfaction is disabled and proper RepartitionExec is inserted. 
+########## + +########## +# SETUP: Create partitioned parquet files to simulate distributed data +########## + +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + +# Create partition 1 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'nike', 100), + ('store', 'nike', 200), + ('store', 'adidas', 150) +)) +TO 'test_files/scratch/grouping_set_repartition/part=1/data.parquet' +STORED AS PARQUET; + +# Create partition 2 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'adidas', 250), + ('web', 'nike', 300), + ('web', 'nike', 400) +)) +TO 'test_files/scratch/grouping_set_repartition/part=2/data.parquet' +STORED AS PARQUET; + +# Create partition 3 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('web', 'adidas', 350), + ('web', 'adidas', 450), + ('catalog', 'nike', 500) +)) +TO 'test_files/scratch/grouping_set_repartition/part=3/data.parquet' +STORED AS PARQUET; + +# Create partition 4 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('catalog', 'nike', 600), + ('catalog', 'adidas', 550), + ('catalog', 'adidas', 650) +)) +TO 'test_files/scratch/grouping_set_repartition/part=4/data.parquet' +STORED AS PARQUET; + +# Create external table pointing to the partitioned data +statement ok +CREATE EXTERNAL TABLE sales (channel VARCHAR, brand VARCHAR, amount INT) +STORED AS PARQUET +PARTITIONED BY (part INT) +LOCATION 'test_files/scratch/grouping_set_repartition/'; + +########## +# TEST 1: UNION ALL + ROLLUP pattern (similar to TPC-DS q14) +# This query pattern triggers the subset satisfaction bug because: +# - Each UNION ALL branch has hash partitioning on (brand) +# - The outer ROLLUP requires hash partitioning on (channel, brand, __grouping_id) +# - Without the fix, subset 
satisfaction incorrectly skips repartition +# +# Verify the physical plan includes RepartitionExec with __grouping_id +########## + +query TT +EXPLAIN SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +logical_plan +01)Sort: sub.channel ASC NULLS FIRST, sub.brand ASC NULLS FIRST +02)--Projection: sub.channel, sub.brand, sum(sub.total) AS grand_total +03)----Aggregate: groupBy=[[ROLLUP (sub.channel, sub.brand)]], aggr=[[sum(sub.total)]] +04)------SubqueryAlias: sub +05)--------Union +06)----------Projection: Utf8("store") AS channel, sales.brand, sum(sales.amount) AS total +07)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +08)--------------Projection: sales.brand, sales.amount +09)----------------Filter: sales.channel = Utf8View("store") +10)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("store")] +11)----------Projection: Utf8("web") AS channel, sales.brand, sum(sales.amount) AS total +12)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +13)--------------Projection: sales.brand, sales.amount +14)----------------Filter: sales.channel = Utf8View("web") +15)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("web")] +16)----------Projection: Utf8("catalog") AS channel, sales.brand, sum(sales.amount) AS total +17)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] 
+18)--------------Projection: sales.brand, sales.amount +19)----------------Filter: sales.channel = Utf8View("catalog") +20)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("catalog")] +physical_plan +01)SortPreservingMergeExec: [channel@0 ASC, brand@1 ASC] +02)--SortExec: expr=[channel@0 ASC, brand@1 ASC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[channel@0 as channel, brand@1 as brand, sum(sub.total)@3 as grand_total] +04)------AggregateExec: mode=FinalPartitioned, gby=[channel@0 as channel, brand@1 as brand, __grouping_id@2 as __grouping_id], aggr=[sum(sub.total)] +05)--------RepartitionExec: partitioning=Hash([channel@0, brand@1, __grouping_id@2], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[(NULL as channel, NULL as brand), (channel@0 as channel, NULL as brand), (channel@0 as channel, brand@1 as brand)], aggr=[sum(sub.total)] +07)------------InterleaveExec +08)--------------ProjectionExec: expr=[store as channel, brand@0 as brand, sum(sales.amount)@1 as total] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +10)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +12)----------------------FilterExec: channel@0 = store, projection=[brand@1, amount@2] +13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, 
amount], file_type=parquet, predicate=channel@0 = store, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= store AND store <= channel_max@1, required_guarantees=[channel in (store)] +14)--------------ProjectionExec: expr=[web as channel, brand@0 as brand, sum(sales.amount)@1 as total] +15)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +16)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +17)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +18)----------------------FilterExec: channel@0 = web, projection=[brand@1, amount@2] +19)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = web, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= web AND web <= channel_max@1, required_guarantees=[channel in (web)] +20)--------------ProjectionExec: expr=[catalog as channel, brand@0 as brand, sum(sales.amount)@1 as total] +21)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +22)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +23)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +24)----------------------FilterExec: channel@0 = catalog, projection=[brand@1, amount@2] +25)------------------------DataSourceExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = catalog, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= catalog AND catalog <= channel_max@1, required_guarantees=[channel in (catalog)] + +query TTI rowsort +SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# TEST 2: Simple ROLLUP (baseline test) +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# TEST 3: Verify CUBE also works correctly +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY CUBE(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS 
FIRST; +---- +NULL NULL 4500 +NULL adidas 2400 +NULL nike 2100 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# CLEANUP +########## + +statement ok +DROP TABLE sales; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 2039ee93df837..b61ceecb24fc0 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -297,6 +297,7 @@ datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown true datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_dynamic_filter_pushdown true datafusion.optimizer.enable_join_dynamic_filter_pushdown true +datafusion.optimizer.enable_leaf_expression_pushdown true datafusion.optimizer.enable_piecewise_merge_join false datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_sort_pushdown true @@ -395,7 +396,7 @@ datafusion.execution.parquet.skip_arrow_metadata false (writing) Skip encoding t datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata datafusion.execution.parquet.statistics_enabled page (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting datafusion.execution.parquet.statistics_truncate_length 64 (writing) Sets statistics truncate length. 
If NULL, uses default parquet writer setting -datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes +datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in rows datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" datafusion.execution.perfect_hash_join_min_key_density 0.15 The minimum required density of join keys on the build side to consider a perfect hash join (see `HashJoinExec` for more details). Density is calculated as: `(number of rows) / (max_key - min_key + 1)`. A perfect hash join may be used if the actual key density > this value. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. datafusion.execution.perfect_hash_join_small_build_threshold 1024 A perfect hash join (see `HashJoinExec` for more details) will be considered if the range of keys (max - min) on the build side is < this threshold. This provides a fast path for joins with very small key ranges, bypassing the density check. Currently only supports cases where build_side.num_rows() < u32::MAX. Support for build_side.num_rows() >= u32::MAX will be added in the future. @@ -434,6 +435,7 @@ datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown true When set to t datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. 
For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. datafusion.optimizer.enable_join_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down Join dynamic filters into the file scan phase. +datafusion.optimizer.enable_leaf_expression_pushdown true When set to true, the optimizer will extract leaf expressions (such as `get_field`) from filter/sort/join nodes into projections closer to the leaf table scans, and push those projections down towards the leaf nodes. datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_sort_pushdown true Enable sort pushdown optimization. When enabled, attempts to push sort requirements down to data sources that can natively handle them (e.g., by reversing file/row group read order). Returns **inexact ordering**: Sort operator is kept for correctness, but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), providing significant speedup. 
Memory: No additional overhead (only changes read order). Future: Will add option to detect perfectly sorted data and eliminate Sort completely. Default: true @@ -799,9 +801,9 @@ select * from information_schema.routines where routine_name = 'date_trunc' OR r ---- datafusion public date_trunc datafusion public date_trunc FUNCTION true Date SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) datafusion public date_trunc datafusion public date_trunc FUNCTION true String SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Time(Nanosecond) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Nanosecond, None) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Nanosecond, Some("+TZ")) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Time(ns) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(ns) SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(ns, "+TZ") SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. 
This function provides a ranking similar to `row_number`, but skips ranks for identical values. rank() datafusion public string_agg datafusion public string_agg FUNCTION true String AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) @@ -821,14 +823,14 @@ datafusion public date_trunc 1 IN precision String NULL false 1 datafusion public date_trunc 2 IN expression String NULL false 1 datafusion public date_trunc 1 OUT NULL String NULL false 1 datafusion public date_trunc 1 IN precision String NULL false 2 -datafusion public date_trunc 2 IN expression Time(Nanosecond) NULL false 2 -datafusion public date_trunc 1 OUT NULL Time(Nanosecond) NULL false 2 +datafusion public date_trunc 2 IN expression Time(ns) NULL false 2 +datafusion public date_trunc 1 OUT NULL Time(ns) NULL false 2 datafusion public date_trunc 1 IN precision String NULL false 3 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 3 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 3 +datafusion public date_trunc 2 IN expression Timestamp(ns) NULL false 3 +datafusion public date_trunc 1 OUT NULL Timestamp(ns) NULL false 3 datafusion public date_trunc 1 IN precision String NULL false 4 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 4 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 4 +datafusion public date_trunc 2 IN expression Timestamp(ns, "+TZ") NULL false 4 +datafusion public date_trunc 1 OUT NULL Timestamp(ns, "+TZ") NULL false 4 datafusion public string_agg 2 IN delimiter Null NULL false 0 datafusion public string_agg 1 IN expression 
String NULL false 0 datafusion public string_agg 1 OUT NULL String NULL false 0 @@ -856,9 +858,9 @@ show functions like 'date_trunc'; ---- date_trunc Date [precision, expression] [String, Date] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) date_trunc String [precision, expression] [String, String] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -date_trunc Time(Nanosecond) [precision, expression] [String, Time(Nanosecond)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [String, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [String, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc Time(ns) [precision, expression] [String, Time(ns)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(ns) [precision, expression] [String, Timestamp(ns)] SCALAR Truncates a timestamp or time value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(ns, "+TZ") [precision, expression] [String, Timestamp(ns, "+TZ")] SCALAR Truncates a timestamp or time value to a specified precision. 
date_trunc(precision, expression) statement ok show functions diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 8ef2596f18e33..e7b9e77dfef58 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -165,7 +165,7 @@ ORDER BY c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: a1 AS a1, a2 AS a2 +02)--Projection: a1, a2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST 04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index 5d111374ac8cf..c0a838c97d552 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -973,19 +973,19 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND 
e.name = Utf8View("Alice") OR e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] physical_plan -01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 = Carol 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] -04)------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +04)------FilterExec: name@1 = Alice OR name@1 = Carol 05)--------DataSourceExec: partitions=1, partition_sizes=[1] 06)------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index c16b3528aa7a5..59f3d8285af49 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -55,7 +55,7 @@ logical_plan 07)--------TableScan: annotated_data projection=[a, c] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 -02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1], fetch=5 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, 
projection=[c], file_type=csv, has_header=true 04)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true 05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -88,18 +88,22 @@ logical_plan 02)--Projection: t2.a AS a2, t2.b 03)----RightSemi Join: t1.d = t2.d, t1.c = t2.c 04)------SubqueryAlias: t1 -05)--------TableScan: annotated_data projection=[c, d] -06)------SubqueryAlias: t2 -07)--------Filter: annotated_data.d = Int32(3) -08)----------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] +05)--------Filter: annotated_data.d = Int32(3) +06)----------TableScan: annotated_data projection=[c, d], partial_filters=[annotated_data.d = Int32(3)] +07)------SubqueryAlias: t2 +08)--------Filter: annotated_data.d = Int32(3) +09)----------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] physical_plan 01)SortPreservingMergeExec: [a2@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 02)--ProjectionExec: expr=[a@0 as a2, b@1 as b] -03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1] -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], file_type=csv, has_header=true -05)------FilterExec: d@3 = 3 -06)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true -07)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], file_type=csv, has_header=true +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1], fetch=10 
+04)------CoalescePartitionsExec +05)--------FilterExec: d@1 = 3 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], file_type=csv, has_header=true +08)------FilterExec: d@3 = 3 +09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +10)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], file_type=csv, has_header=true # preserve_right_semi_join query II nosort diff --git a/datafusion/sqllogictest/test_files/join_limit_pushdown.slt b/datafusion/sqllogictest/test_files/join_limit_pushdown.slt new file mode 100644 index 0000000000000..6bb23c1b4c243 --- /dev/null +++ b/datafusion/sqllogictest/test_files/join_limit_pushdown.slt @@ -0,0 +1,269 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Tests for limit pushdown into joins + +# need to use a single partition for deterministic results +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# Create test tables +statement ok +CREATE TABLE t1 (a INT, b VARCHAR) AS VALUES + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + (5, 'five'); + +statement ok +CREATE TABLE t2 (x INT, y VARCHAR) AS VALUES + (1, 'alpha'), + (2, 'beta'), + (3, 'gamma'), + (6, 'delta'), + (7, 'epsilon'); + +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 2; +---- +1 1 +2 2 + +# Right join is converted to Left join with projection - fetch pushdown is supported +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 RIGHT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +logical_plan +01)Limit: skip=0, fetch=3 +02)--Right Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Limit: skip=0, fetch=3 +05)------TableScan: t2 projection=[x], fetch=3 +physical_plan +01)ProjectionExec: expr=[a@1 as a, x@0 as x] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(x@0, a@0)], fetch=3 +03)----DataSourceExec: partitions=1, partition_sizes=[1], fetch=3 +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 RIGHT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +1 1 +2 2 +3 3 + +# Left join supports fetch pushdown +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- 
+logical_plan +01)Limit: skip=0, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Limit: skip=0, fetch=3 +04)------TableScan: t1 projection=[a], fetch=3 +05)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, x@0)], fetch=3 +02)--DataSourceExec: partitions=1, partition_sizes=[1], fetch=3 +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x LIMIT 3; +---- +1 1 +2 2 +3 3 + + +# Full join supports fetch pushdown +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x LIMIT 4; +---- +logical_plan +01)Limit: skip=0, fetch=4 +02)--Full Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, x@0)], fetch=4 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Note: FULL OUTER JOIN order is not deterministic, so we just check count +query I +SELECT COUNT(*) FROM (SELECT t1.a, t2.x FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x LIMIT 4); +---- +4 + +# EXISTS becomes left semi join - fetch pushdown is supported +query TT +EXPLAIN SELECT t2.x FROM t2 WHERE EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--LeftSemi Join: t2.x = __correlated_sq_1.a +03)----TableScan: t2 projection=[x] +04)----SubqueryAlias: __correlated_sq_1 +05)------TableScan: t1 projection=[a] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(x@0, a@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.x FROM t2 WHERE EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 2; +---- +1 +2 + +# NOT EXISTS becomes LeftAnti - fetch pushdown is supported +query TT +EXPLAIN SELECT t2.x FROM t2 WHERE NOT EXISTS (SELECT 1 FROM t1 WHERE 
t1.a = t2.x) LIMIT 1; +---- +logical_plan +01)Limit: skip=0, fetch=1 +02)--LeftAnti Join: t2.x = __correlated_sq_1.a +03)----TableScan: t2 projection=[x] +04)----SubqueryAlias: __correlated_sq_1 +05)------TableScan: t1 projection=[a] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(x@0, a@0)], fetch=1 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.x FROM t2 WHERE NOT EXISTS (SELECT 1 FROM t1 WHERE t1.a = t2.x) LIMIT 1; +---- +6 + +# Inner join should push +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 1 OFFSET 1; +---- +logical_plan +01)Limit: skip=1, fetch=1 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)GlobalLimitExec: skip=1, fetch=1 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=2 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 1 OFFSET 1; +---- +2 2 + +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 0; +---- +logical_plan EmptyRelation: rows=0 +physical_plan EmptyExec + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 0; +---- + +statement ok +CREATE TABLE t3 (p INT, q VARCHAR) AS VALUES + (1, 'foo'), + (2, 'bar'), + (3, 'baz'); + +query TT +EXPLAIN SELECT t1.a, t2.x, t3.p +FROM t1 +INNER JOIN t2 ON t1.a = t2.x +INNER JOIN t3 ON t2.x = t3.p +LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Inner Join: t2.x = t3.p +03)----Inner Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------TableScan: t2 projection=[x] +06)----TableScan: t3 projection=[p] +physical_plan +01)ProjectionExec: expr=[a@1 as a, x@2 as x, p@0 as p] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(p@0, x@1)], fetch=2 
+03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)] +05)------DataSourceExec: partitions=1, partition_sizes=[1] +06)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT t1.a, t2.x, t3.p +FROM t1 +INNER JOIN t2 ON t1.a = t2.x +INNER JOIN t3 ON t2.x = t3.p +LIMIT 2; +---- +1 1 1 +2 2 2 + +# Try larger limit +query TT +EXPLAIN SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 100; +---- +logical_plan +01)Limit: skip=0, fetch=100 +02)--Inner Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, x@0)], fetch=100 +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, t2.x FROM t1 INNER JOIN t2 ON t1.a = t2.x LIMIT 100; +---- +1 1 +2 2 +3 3 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 38037ede21db2..2fb544a638d61 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -57,15 +57,15 @@ statement ok CREATE TABLE join_t3(s3 struct) AS VALUES (NULL), - (struct(1)), - (struct(2)); + ({id: 1}), + ({id: 2}); statement ok CREATE TABLE join_t4(s4 struct) AS VALUES (NULL), - (struct(2)), - (struct(3)); + ({id: 2}), + ({id: 3}); # Left semi anti join @@ -146,10 +146,10 @@ AS VALUES statement ok CREATE TABLE test_timestamps_table as SELECT - arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, None)') as nanos, - arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, None)') as micros, - arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, None)') as millis, - arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, None)') 
as secs, + arrow_cast(ts::timestamp::bigint, 'Timestamp(ns)') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(µs)') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(ms)') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(s)') as secs, names FROM test_timestamps_table_source; @@ -2085,7 +2085,7 @@ SELECT join_t1.t1_id, join_t2.t2_id FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 ON join_t1.t1_id < join_t2.t2_id -ORDER BY 1, 2 +ORDER BY 1, 2 ---- 33 44 33 55 @@ -3516,7 +3516,6 @@ AS VALUES query IT SELECT t1_id, t1_name FROM join_test_left WHERE t1_id NOT IN (SELECT t2_id FROM join_test_right) ORDER BY t1_id; ---- -NULL e #### # join_partitioned_test @@ -3955,7 +3954,7 @@ query TT explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--SubqueryAlias: t1 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series @@ -4162,10 +4161,9 @@ logical_plan 03)----TableScan: t0 projection=[c1, c2] 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[2] +03)--DataSourceExec: partitions=1, partition_sizes=[2] ## Test join.on.is_empty() && join.filter.is_some() -> single filter now a PWMJ query TT @@ -4192,10 +4190,9 @@ logical_plan 03)----TableScan: t0 projection=[c1, c2] 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, 
on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1, fetch=2 +02)--DataSourceExec: partitions=1, partition_sizes=[2] +03)--DataSourceExec: partitions=1, partition_sizes=[2] ## Add more test cases for join limit pushdown statement ok @@ -4246,6 +4243,7 @@ select * from t1 LEFT JOIN t2 ON t1.a = t2.b LIMIT 2; 1 1 # can only push down to t1 (preserved side) +# limit pushdown supported for left join - both to join and probe side query TT explain select * from t1 LEFT JOIN t2 ON t1.a = t2.b LIMIT 2; ---- @@ -4256,10 +4254,9 @@ logical_plan 04)------TableScan: t1 projection=[a], fetch=2 05)----TableScan: t2 projection=[b] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true ###### ## RIGHT JOIN w/ LIMIT @@ -4290,10 +4287,9 @@ logical_plan 04)----Limit: skip=0, fetch=2 05)------TableScan: t2 projection=[b], fetch=2 physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, b@0)] -03)----DataSourceExec: 
file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true ###### ## FULL JOIN w/ LIMIT @@ -4317,7 +4313,7 @@ select * from t1 FULL JOIN t2 ON t1.a = t2.b LIMIT 2; 4 4 -# can't push limit for full outer join +# full outer join supports fetch pushdown query TT explain select * from t1 FULL JOIN t2 ON t1.a = t2.b LIMIT 2; ---- @@ -4327,10 +4323,9 @@ logical_plan 03)----TableScan: t1 projection=[a] 04)----TableScan: t2 projection=[b] physical_plan -01)GlobalLimitExec: skip=0, fetch=2 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true +01)HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, b@0)], fetch=2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true 
statement ok drop table t1; @@ -4368,10 +4363,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Single, gby=[], aggr=[count(Int64(1))] -03)----ProjectionExec: expr=[] -04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] -06)--------DataSourceExec: partitions=1, partition_sizes=[1] +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)], projection=[] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] # Test hash join sort push down # Issue: https://github.com/apache/datafusion/issues/13559 @@ -4533,7 +4527,7 @@ query TT explain SELECT * FROM person a NATURAL JOIN lineitem b; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--SubqueryAlias: a 03)----TableScan: person projection=[id, age, state] 04)--SubqueryAlias: b @@ -4579,7 +4573,7 @@ query TT explain SELECT j1_string, j2_string FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string] 03)--SubqueryAlias: j2 04)----Projection: j2.j2_string @@ -4592,7 +4586,7 @@ query TT explain SELECT * FROM j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id), LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4 ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--Inner Join: CAST(j2.j2_id AS Int64) = CAST(j3.j3_id AS Int64) - Int64(2) 03)----Inner Join: j1.j1_id = j2.j2_id 04)------TableScan: j1 projection=[j1_string, j1_id] @@ -4608,11 +4602,11 @@ query TT explain SELECT * FROM j1, LATERAL (SELECT * FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id = j2_id) as j2) as j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] 03)--SubqueryAlias: j2 04)----Subquery: -05)------Cross Join: +05)------Cross 
Join: 06)--------TableScan: j1 projection=[j1_string, j1_id] 07)--------SubqueryAlias: j2 08)----------Subquery: @@ -4624,7 +4618,7 @@ query TT explain SELECT j1_string, j2_string FROM j1 LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true); ---- logical_plan -01)Left Join: +01)Left Join: 02)--TableScan: j1 projection=[j1_string] 03)--SubqueryAlias: j2 04)----Projection: j2.j2_string @@ -4637,9 +4631,9 @@ query TT explain SELECT * FROM j1, (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true)); ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] -03)--Left Join: +03)--Left Join: 04)----TableScan: j2 projection=[j2_string, j2_id] 05)----SubqueryAlias: j3 06)------Subquery: @@ -4651,7 +4645,7 @@ query TT explain SELECT * FROM j1, LATERAL (SELECT 1) AS j2; ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--TableScan: j1 projection=[j1_string, j1_id] 03)--SubqueryAlias: j2 04)----Projection: Int64(1) @@ -4993,7 +4987,7 @@ FULL JOIN t2 ON k1 = k2 # LEFT MARK JOIN query TT -EXPLAIN +EXPLAIN SELECT * FROM t2 WHERE k2 > 0 @@ -5050,9 +5044,10 @@ WHERE k1 < 0 ---- physical_plan 01)HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k2@0, k1@0)] -02)--DataSourceExec: partitions=1, partition_sizes=[0] -03)--FilterExec: k1@0 < 0 -04)----DataSourceExec: partitions=1, partition_sizes=[10000] +02)--FilterExec: k2@0 < 0 +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)--FilterExec: k1@0 < 0 +05)----DataSourceExec: partitions=1, partition_sizes=[10000] query II SELECT * @@ -5067,14 +5062,14 @@ CREATE OR REPLACE TABLE t1(b INT, c INT, d INT); statement ok INSERT INTO t1 VALUES - (10, 5, 3), - ( 1, 7, 8), - ( 2, 9, 7), - ( 3, 8,10), - ( 5, 6, 6), - ( 0, 4, 9), - ( 4, 8, 7), - (100,6, 5); + (10, 5, 3), + ( 1, 7, 8), + ( 2, 9, 7), + ( 3, 8,10), + ( 5, 6, 6), + ( 0, 4, 9), + ( 4, 8, 7), + (100,6, 5); query I rowsort SELECT c @@ -5198,3 +5193,100 @@ DROP TABLE t1_c; 
statement ok DROP TABLE t2_c; + +# Reproducer of https://github.com/apache/datafusion/issues/19067 +statement count 0 +set datafusion.explain.physical_plan_only = true; + +# Setup Left Table with FixedSizeBinary(4) +statement count 0 +CREATE TABLE issue_19067_left AS +SELECT + column1 as id, + arrow_cast(decode(column2, 'hex'), 'FixedSizeBinary(4)') as join_key +FROM (VALUES + (1, 'AAAAAAAA'), + (2, 'BBBBBBBB'), + (3, 'CCCCCCCC') +); + +# Setup Right Table with FixedSizeBinary(4) +statement count 0 +CREATE TABLE issue_19067_right AS +SELECT + arrow_cast(decode(column1, 'hex'), 'FixedSizeBinary(4)') as join_key, + column2 as value +FROM (VALUES + ('AAAAAAAA', 1000), + ('BBBBBBBB', 2000) +); + +# Perform Left Join. Third row should contain NULL in `right_key`. +query I??I +SELECT + l.id, + l.join_key as left_key, + r.join_key as right_key, + r.value +FROM issue_19067_left l +LEFT JOIN issue_19067_right r ON l.join_key = r.join_key +ORDER BY l.id; +---- +1 aaaaaaaa aaaaaaaa 1000 +2 bbbbbbbb bbbbbbbb 2000 +3 cccccccc NULL NULL + +# Ensure usage of HashJoinExec +query TT +EXPLAIN +SELECT + l.id, + l.join_key as left_key, + r.join_key as right_key, + r.value +FROM issue_19067_left l +LEFT JOIN issue_19067_right r ON l.join_key = r.join_key +ORDER BY l.id; +---- +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@2 as id, join_key@3 as left_key, join_key@0 as right_key, value@1 as value] +03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(join_key@0, join_key@1)] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] + +statement count 0 +set datafusion.explain.physical_plan_only = false; + +statement count 0 +DROP TABLE issue_19067_left; + +statement count 0 +DROP TABLE issue_19067_right; + +# Test that empty projections pushed into joins produce correct row counts at runtime. 
+# When count(1) is used over a RIGHT/FULL JOIN, the optimizer embeds an empty projection +# (projection=[]) into the HashJoinExec. This validates that the runtime batch construction +# handles zero-column output correctly, preserving the correct number of rows. + +statement ok +CREATE TABLE empty_proj_left AS VALUES (1, 'a'), (2, 'b'), (3, 'c'); + +statement ok +CREATE TABLE empty_proj_right AS VALUES (1, 'x'), (2, 'y'), (4, 'z'); + +query I +SELECT count(1) FROM empty_proj_left RIGHT JOIN empty_proj_right ON empty_proj_left.column1 = empty_proj_right.column1; +---- +3 + +query I +SELECT count(1) FROM empty_proj_left FULL JOIN empty_proj_right ON empty_proj_left.column1 = empty_proj_right.column1; +---- +4 + +statement count 0 +DROP TABLE empty_proj_left; + +statement count 0 +DROP TABLE empty_proj_right; diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index b46b8c49d6623..60bec4213db02 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -146,3 +146,31 @@ EXPLAIN SELECT id FROM json_partitioned_test WHERE part = 2 ---- logical_plan TableScan: json_partitioned_test projection=[id], full_filters=[json_partitioned_test.part = Int32(2)] physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_json/part=2/data.json]]}, projection=[id], file_type=json + +########## +## JSON Array Format Tests +########## + +# Test reading JSON array format file with newline_delimited=false +statement ok +CREATE EXTERNAL TABLE json_array_test +STORED AS JSON +OPTIONS ('format.newline_delimited' 'false') +LOCATION '../core/tests/data/json_array.json'; + +query IT rowsort +SELECT a, b FROM json_array_test +---- +1 hello +2 world +3 test + +statement ok +DROP TABLE json_array_test; + +# Test that reading JSON array format WITHOUT newline_delimited option fails +# (default is newline_delimited=true which can't parse array 
format correctly) +statement error Not valid JSON +CREATE EXTERNAL TABLE json_array_as_ndjson +STORED AS JSON +LOCATION '../core/tests/data/json_array.json'; diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 524304546d569..ec8363f51acfa 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -706,8 +706,8 @@ ON t1.b = t2.b ORDER BY t1.b desc, c desc, c2 desc OFFSET 3 LIMIT 2; ---- -3 99 82 -3 99 79 +3 98 79 +3 97 96 statement ok drop table ordered_table; diff --git a/datafusion/sqllogictest/test_files/limit_pruning.slt b/datafusion/sqllogictest/test_files/limit_pruning.slt new file mode 100644 index 0000000000000..72672b707d4f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/limit_pruning.slt @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + + +statement ok +CREATE TABLE tracking_data AS VALUES +-- ***** Row Group 0 ***** + ('Anow Vole', 7), + ('Brown Bear', 133), + ('Gray Wolf', 82), +-- ***** Row Group 1 ***** + ('Lynx', 71), + ('Red Fox', 40), + ('Alpine Bat', 6), +-- ***** Row Group 2 ***** + ('Nlpine Ibex', 101), + ('Nlpine Goat', 76), + ('Nlpine Sheep', 83), +-- ***** Row Group 3 ***** + ('Europ. Mole', 4), + ('Polecat', 16), + ('Alpine Ibex', 97); + +statement ok +COPY (SELECT column1 as species, column2 as s FROM tracking_data) +TO 'test_files/scratch/limit_pruning/data.parquet' +STORED AS PARQUET +OPTIONS ( + 'format.max_row_group_size' '3' +); + +statement ok +drop table tracking_data; + +statement ok +CREATE EXTERNAL TABLE tracking_data +STORED AS PARQUET +LOCATION 'test_files/scratch/limit_pruning/data.parquet'; + + +statement ok +set datafusion.explain.analyze_level = summary; + +# row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched +# limit_pruned_row_groups=2 total → 0 matched +query TT +explain analyze select * from tracking_data where species > 'M' AND s >= 50 limit 3; +---- +Plan with Metrics DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit_pruning/data.parquet]]}, projection=[species, s], limit=3, file_type=parquet, predicate=species@0 > M AND s@1 >= 50, pruning_predicate=species_null_count@1 != row_count@2 AND species_max@0 > M AND s_null_count@4 != row_count@2 AND s_max@3 >= 50, required_guarantees=[], metrics=[output_rows=3, elapsed_compute=, output_bytes=, files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched, row_groups_pruned_bloom_filter=3 total → 3 matched, page_index_pages_pruned=2 total → 2 matched, limit_pruned_row_groups=2 total → 0 matched, bytes_scanned=, metadata_load_time=, scan_efficiency_ratio= (171/2.35 K)] + +# limit_pruned_row_groups=0 total → 0 matched 
+# because of order by, scan needs to preserve sort, so limit pruning is disabled +query TT +explain analyze select * from tracking_data where species > 'M' AND s >= 50 order by species limit 3; +---- +Plan with Metrics +01)SortExec: TopK(fetch=3), expr=[species@0 ASC NULLS LAST], preserve_partitioning=[false], filter=[species@0 < Nlpine Sheep], metrics=[output_rows=3, elapsed_compute=, output_bytes=] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit_pruning/data.parquet]]}, projection=[species, s], file_type=parquet, predicate=species@0 > M AND s@1 >= 50 AND DynamicFilter [ species@0 < Nlpine Sheep ], pruning_predicate=species_null_count@1 != row_count@2 AND species_max@0 > M AND s_null_count@4 != row_count@2 AND s_max@3 >= 50 AND species_null_count@1 != row_count@2 AND species_min@5 < Nlpine Sheep, required_guarantees=[], metrics=[output_rows=3, elapsed_compute=, output_bytes=, files_ranges_pruned_statistics=1 total → 1 matched, row_groups_pruned_statistics=4 total → 3 matched -> 1 fully matched, row_groups_pruned_bloom_filter=3 total → 3 matched, page_index_pages_pruned=6 total → 6 matched, limit_pruned_row_groups=0 total → 0 matched, bytes_scanned=, metadata_load_time=, scan_efficiency_ratio= (521/2.35 K)] + +statement ok +drop table tracking_data; + +statement ok +reset datafusion.explain.analyze_level; diff --git a/datafusion/sqllogictest/test_files/limit_single_row_batches.slt b/datafusion/sqllogictest/test_files/limit_single_row_batches.slt new file mode 100644 index 0000000000000..9f626816e2146 --- /dev/null +++ b/datafusion/sqllogictest/test_files/limit_single_row_batches.slt @@ -0,0 +1,22 @@ + +# minimize batch size to 1 in order to trigger different code paths +statement ok +set datafusion.execution.batch_size = '1'; + +# ---- +# tests with target partition set to 1 +# ---- +statement ok +set datafusion.execution.target_partitions = '1'; + + +statement ok +CREATE TABLE filter_limit (i INT) as 
values (1), (2); + +query I +SELECT COUNT(*) FROM (SELECT i FROM filter_limit WHERE i <> 0 LIMIT 1); +---- +1 + +statement ok +DROP TABLE filter_limit; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 71a969c751591..2227466fdf254 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -111,12 +111,44 @@ SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL: ---- false true true NULL +# isnan: non-float numeric inputs are never NaN +query BBBB +SELECT isnan(1::INT), isnan(0::INT), isnan(NULL::INT), isnan(123::BIGINT) +---- +false false NULL false + +query BBBB +SELECT isnan(1::INT UNSIGNED), isnan(0::INT UNSIGNED), isnan(NULL::INT UNSIGNED), isnan(255::TINYINT UNSIGNED) +---- +false false NULL false + +query BBBB +SELECT isnan(1::DECIMAL(10,2)), isnan(0::DECIMAL(10,2)), isnan(NULL::DECIMAL(10,2)), isnan(-1::DECIMAL(10,2)) +---- +false false NULL false + # iszero query BBBB SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL +# iszero: integers / unsigned / decimals +query BBBB +SELECT iszero(1::INT), iszero(0::INT), iszero(NULL::INT), iszero(-1::INT) +---- +false true NULL false + +query BBBB +SELECT iszero(1::INT UNSIGNED), iszero(0::INT UNSIGNED), iszero(NULL::INT UNSIGNED), iszero(255::TINYINT UNSIGNED) +---- +false true NULL false + +query BBBB +SELECT iszero(1::DECIMAL(10,2)), iszero(0::DECIMAL(10,2)), iszero(NULL::DECIMAL(10,2)), iszero(-1::DECIMAL(10,2)) +---- +false true NULL false + # abs: empty argument statement error SELECT abs(); diff --git a/datafusion/sqllogictest/test_files/null_aware_anti_join.slt b/datafusion/sqllogictest/test_files/null_aware_anti_join.slt new file mode 100644 index 0000000000000..5907a85a9b923 --- /dev/null +++ b/datafusion/sqllogictest/test_files/null_aware_anti_join.slt @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or 
more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Null-Aware Anti Join Tests +## Tests for automatic null-aware semantics in NOT IN subqueries +############# + +statement ok +CREATE TABLE outer_table(id INT, value TEXT) AS VALUES +(1, 'a'), +(2, 'b'), +(3, 'c'), +(4, 'd'), +(NULL, 'e'); + +statement ok +CREATE TABLE inner_table_no_null(id INT, value TEXT) AS VALUES +(2, 'x'), +(4, 'y'); + +statement ok +CREATE TABLE inner_table_with_null(id INT, value TEXT) AS VALUES +(2, 'x'), +(NULL, 'y'); + +############# +## Test 1: NOT IN with no NULLs - should behave like regular anti join +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null); +---- +1 a +3 c + +# Verify the plan uses LeftAnti join +query TT +EXPLAIN SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null); +---- +logical_plan +01)LeftAnti Join: outer_table.id = __correlated_sq_1.id +02)--TableScan: outer_table projection=[id, value] +03)--SubqueryAlias: __correlated_sq_1 +04)----TableScan: inner_table_no_null projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## 
Test 2: NOT IN with NULL in subquery - should return 0 rows (null-aware semantics) +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_with_null); +---- + +# Verify the result is empty even though there are rows in outer_table +# that don't match the non-NULL value (2) in the subquery. +# This is correct null-aware behavior: if subquery contains NULL, result is unknown. + +############# +## Test 3: NOT IN with NULL in outer table but not in subquery +## NULL rows from outer should not appear in output +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_no_null) AND id IS NOT NULL; +---- +1 a +3 c + +############# +## Test 4: Test with all NULL subquery +############# + +statement ok +CREATE TABLE all_null_table(id INT) AS VALUES (NULL), (NULL); + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM all_null_table); +---- + +############# +## Test 5: Test with empty subquery - should return all rows +############# + +statement ok +CREATE TABLE empty_table(id INT, value TEXT); + +query IT rowsort +SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM empty_table); +---- +1 a +2 b +3 c +4 d +NULL e + +############# +## Test 6: NOT IN with complex expression +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id + 1 NOT IN (SELECT id FROM inner_table_no_null); +---- +2 b +4 d + +############# +## Test 7: NOT IN with complex expression and NULL in subquery +############# + +query IT rowsort +SELECT * FROM outer_table WHERE id + 1 NOT IN (SELECT id FROM inner_table_with_null); +---- + +############# +## Test 8: Multiple NOT IN conditions (AND) +############# + +statement ok +CREATE TABLE inner_table2(id INT) AS VALUES (1), (3); + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_no_null) + AND id NOT IN (SELECT id FROM inner_table2); +---- + +############# +## Test 9: Multiple NOT 
IN conditions (OR) +############# + +# KNOWN LIMITATION: Mark joins used for OR conditions don't support null-aware semantics. +# The NULL row is incorrectly returned here. According to SQL semantics: +# - NULL NOT IN (2, 4) = UNKNOWN +# - NULL NOT IN (1, 3) = UNKNOWN +# - UNKNOWN OR UNKNOWN = UNKNOWN (should be filtered out) +# But mark joins treat NULL keys as non-matching (FALSE), so: +# - NULL mark column = FALSE +# - NOT FALSE OR NOT FALSE = TRUE OR TRUE = TRUE (incorrectly included) +# TODO: Implement null-aware support for mark joins to fix this + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_no_null) + OR id NOT IN (SELECT id FROM inner_table2); +---- +1 a +2 b +3 c +4 d +NULL e + +############# +## Test 10: NOT IN with WHERE clause in subquery +############# + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT id FROM inner_table_with_null WHERE value = 'x'); +---- +1 a +3 c +4 d + +# Note: The NULL row from inner_table_with_null is filtered out by WHERE clause, +# so this behaves like regular anti join (not null-aware) + +############# +## Test 11: Verify NULL-aware flag is set for LeftAnti joins +############# + +# Check that the physical plan shows null-aware anti join +# Note: The exact format may vary, but we should see LeftAnti join type +query TT +EXPLAIN SELECT * FROM outer_table WHERE id NOT IN (SELECT id FROM inner_table_with_null); +---- +logical_plan +01)LeftAnti Join: outer_table.id = __correlated_sq_1.id +02)--TableScan: outer_table projection=[id, value] +03)--SubqueryAlias: __correlated_sq_1 +04)----TableScan: inner_table_with_null projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## Test 12: Correlated NOT IN subquery with NULL +############# + +statement ok +CREATE TABLE orders(order_id INT, 
customer_id INT) AS VALUES +(1, 100), +(2, 200), +(3, 300); + +statement ok +CREATE TABLE payments(payment_id INT, order_id INT) AS VALUES +(1, 1), +(2, NULL); + +# Find orders that don't have payments +# Should return empty because there's a NULL in payments.order_id +query I rowsort +SELECT order_id FROM orders +WHERE order_id NOT IN (SELECT order_id FROM payments); +---- + +############# +## Test 13: NOT IN with DISTINCT in subquery +############# + +statement ok +CREATE TABLE duplicates_with_null(id INT) AS VALUES +(2), +(2), +(NULL), +(NULL); + +query IT rowsort +SELECT * FROM outer_table +WHERE id NOT IN (SELECT DISTINCT id FROM duplicates_with_null); +---- + +############# +## Test 14: NOT EXISTS vs NOT IN - Demonstrating the difference +############# + +# NOT EXISTS should NOT use null-aware semantics +# It uses two-valued logic (TRUE/FALSE), not three-valued logic (TRUE/FALSE/UNKNOWN) + +# Setup tables for comparison +statement ok +CREATE TABLE customers(id INT, name TEXT) AS VALUES +(1, 'Alice'), +(2, 'Bob'), +(3, 'Charlie'), +(NULL, 'Dave'); + +statement ok +CREATE TABLE banned(id INT) AS VALUES +(2), +(NULL); + +# Test 14a: NOT IN with NULL in subquery - Returns EMPTY (null-aware) +query IT rowsort +SELECT * FROM customers WHERE id NOT IN (SELECT id FROM banned); +---- + +# Test 14b: NOT EXISTS with NULL in subquery - Returns rows (NOT null-aware) +# This should return (1, 'Alice'), (3, 'Charlie'), (NULL, 'Dave') +# Because NOT EXISTS uses two-valued logic: NULL = NULL is FALSE, so no match found +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE c.id = b.id); +---- +1 Alice +3 Charlie +NULL Dave + +# Test 14c: Verify with EXPLAIN that NOT EXISTS doesn't use null-aware +query TT +EXPLAIN SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE c.id = b.id); +---- +logical_plan +01)LeftAnti Join: c.id = __correlated_sq_1.id +02)--SubqueryAlias: c +03)----TableScan: customers projection=[id, name] 
+04)--SubqueryAlias: __correlated_sq_1 +05)----SubqueryAlias: b +06)------TableScan: banned projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(id@0, id@0)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] +03)--DataSourceExec: partitions=1, partition_sizes=[1] + +############# +## Test 15: NOT EXISTS - No NULLs +############# + +statement ok +CREATE TABLE active_customers(id INT) AS VALUES (1), (3); + +# Should return only Bob (id=2) and Dave (id=NULL) +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM active_customers a WHERE c.id = a.id); +---- +2 Bob +NULL Dave + +############# +## Test 16: NOT EXISTS - Correlated subquery +############# + +statement ok +CREATE TABLE orders_test(order_id INT, customer_id INT) AS VALUES +(1, 100), +(2, 200), +(3, NULL); + +statement ok +CREATE TABLE customers_test(customer_id INT, name TEXT) AS VALUES +(100, 'Alice'), +(200, 'Bob'), +(300, 'Charlie'), +(NULL, 'Unknown'); + +# Find customers with no orders +# Should return Charlie (300) and Unknown (NULL) +query IT rowsort +SELECT * FROM customers_test c +WHERE NOT EXISTS ( + SELECT 1 FROM orders_test o WHERE o.customer_id = c.customer_id +); +---- +300 Charlie +NULL Unknown + +############# +## Test 17: NOT EXISTS with all NULL subquery +############# + +statement ok +CREATE TABLE all_null_banned(id INT) AS VALUES (NULL), (NULL); + +# NOT EXISTS should return all rows because NULL = NULL is FALSE (no matches) +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS (SELECT 1 FROM all_null_banned b WHERE c.id = b.id); +---- +1 Alice +2 Bob +3 Charlie +NULL Dave + +# Compare with NOT IN which returns empty +query IT rowsort +SELECT * FROM customers WHERE id NOT IN (SELECT id FROM all_null_banned); +---- + +############# +## Test 18: Nested NOT EXISTS and NOT IN +############# + +# NOT EXISTS outside, NOT IN inside - should work correctly +query IT rowsort +SELECT * FROM customers c +WHERE NOT EXISTS ( 
+ SELECT 1 FROM banned b + WHERE c.id = b.id + AND b.id NOT IN (SELECT id FROM active_customers) +); +---- +1 Alice +3 Charlie +NULL Dave + +############# +## Test from GitHub issue #10583 +## Tests NOT IN with NULL in subquery result - should return empty result +############# + +statement ok +CREATE TABLE test_table(c1 INT, c2 INT) AS VALUES +(1, 1), +(2, 2), +(3, 3), +(4, NULL), +(NULL, 0); + +# When subquery contains NULL, NOT IN should return empty result +# because NULL NOT IN (values including NULL) is UNKNOWN for all rows +query II rowsort +SELECT * FROM test_table WHERE (c1 NOT IN (SELECT c2 FROM test_table)) = true; +---- + +# NOTE: The correlated subquery version from issue #10583: +# SELECT * FROM test_table t1 WHERE c1 NOT IN (SELECT c2 FROM test_table t2 WHERE t1.c1 = t2.c1) +# is not yet supported because it creates a multi-column join (correlation + NOT IN condition). +# This is a known limitation - currently only supports single column null-aware anti joins. +# This will be addressed in next Phase (multi-column support). 
+ +############# +## Cleanup +############# + +statement ok +DROP TABLE test_table; + +statement ok +DROP TABLE outer_table; + +statement ok +DROP TABLE inner_table_no_null; + +statement ok +DROP TABLE inner_table_with_null; + +statement ok +DROP TABLE all_null_table; + +statement ok +DROP TABLE empty_table; + +statement ok +DROP TABLE inner_table2; + +statement ok +DROP TABLE orders; + +statement ok +DROP TABLE payments; + +statement ok +DROP TABLE duplicates_with_null; + +statement ok +DROP TABLE customers; + +statement ok +DROP TABLE banned; + +statement ok +DROP TABLE active_customers; + +statement ok +DROP TABLE orders_test; + +statement ok +DROP TABLE customers_test; + +statement ok +DROP TABLE all_null_banned; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index aa94e2e2f2c04..e2473ee328e51 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -674,3 +674,66 @@ logical_plan physical_plan 01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/array_data/data.parquet]]}, projection=[id, tags], file_type=parquet, predicate=id@0 > 1 AND array_has(tags@1, rust), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +### +# Test filter pushdown through UNION with mixed support +# This tests the case where one child supports filter pushdown (parquet) and one doesn't (memory table) +### + +# enable filter pushdown +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +statement ok +set datafusion.optimizer.max_passes = 0; + +# Create memory table with matching schema (a: VARCHAR, b: BIGINT) +statement ok +CREATE TABLE t_union_mem(a VARCHAR, b BIGINT) AS VALUES ('qux', 4), ('quux', 
5); + +# Create parquet table with matching schema +statement ok +CREATE EXTERNAL TABLE t_union_parquet(a VARCHAR, b BIGINT) STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet'; + +# Query results combining memory table and Parquet with filter +query I rowsort +SELECT b FROM ( + SELECT a, b FROM t_union_mem + UNION ALL + SELECT a, b FROM t_union_parquet +) WHERE b > 2; +---- +3 +4 +5 +50 + +# Explain the union query - filter should be pushed to parquet but not memory table +query TT +EXPLAIN SELECT b FROM ( + SELECT a, b FROM t_union_mem + UNION ALL + SELECT a, b FROM t_union_parquet +) WHERE b > 2; +---- +logical_plan +01)Projection: b +02)--Filter: b > Int64(2) +03)----Union +04)------Projection: t_union_mem.a, t_union_mem.b +05)--------TableScan: t_union_mem +06)------Projection: t_union_parquet.a, t_union_parquet.b +07)--------TableScan: t_union_parquet +physical_plan +01)UnionExec +02)--FilterExec: b@0 > 2 +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet]]}, projection=[b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +# Clean up union test tables +statement ok +DROP TABLE t_union_mem; + +statement ok +DROP TABLE t_union_parquet; diff --git a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt index 5a559bdb94835..fd3a40ca17079 100644 --- a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt +++ b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt @@ -274,4 +274,4 @@ logical_plan 02)--TableScan: test_table projection=[constant_col] physical_plan 01)SortPreservingMergeExec: [constant_col@0 ASC NULLS LAST] -02)--DataSourceExec: file_groups={2 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[constant_col], file_type=parquet +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[constant_col], output_ordering=[constant_col@0 ASC NULLS LAST], file_type=parquet diff --git a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt index 34c5fd97b51f3..297094fab16e7 100644 --- a/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt +++ b/datafusion/sqllogictest/test_files/preserve_file_partitioning.slt @@ -101,6 +101,29 @@ STORED AS PARQUET; ---- 4 +# Create hive-partitioned dimension table (3 partitions matching fact_table) +# For testing Partitioned joins with matching partition counts +query I +COPY (SELECT 'dev' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'prod' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet' +STORED AS PARQUET; +---- +1 + +query I +COPY (SELECT 'prod' as env, 'log' as service) +TO 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet' +STORED AS PARQUET; +---- +1 + 
# Create high-cardinality fact table (5 partitions > 3 target_partitions) # For testing partition merging with consistent hashing query I @@ -173,6 +196,13 @@ CREATE EXTERNAL TABLE dimension_table (d_dkey STRING, env STRING, service STRING STORED AS PARQUET LOCATION 'test_files/scratch/preserve_file_partitioning/dimension/'; +# Hive-partitioned dimension table (3 partitions matching fact_table for Partitioned join tests) +statement ok +CREATE EXTERNAL TABLE dimension_table_partitioned (env STRING, service STRING) +STORED AS PARQUET +PARTITIONED BY (d_dkey STRING) +LOCATION 'test_files/scratch/preserve_file_partitioning/dimension_partitioned/'; + # 'High'-cardinality fact table (5 partitions > 3 target_partitions) statement ok CREATE EXTERNAL TABLE high_cardinality_table (timestamp TIMESTAMP, value DOUBLE) @@ -579,6 +609,101 @@ C 1 300 D 1 400 E 1 500 +########## +# TEST 11: Partitioned Join with Matching Partition Counts - Without Optimization +# fact_table (3 partitions) joins dimension_table_partitioned (3 partitions) +# Shows RepartitionExec added when preserve_file_partitions is disabled +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 0; + +# Force Partitioned join mode (not CollectLeft) +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold = 0; + +statement ok +set datafusion.optimizer.hash_join_single_partition_threshold_rows = 0; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d +07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: 
mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------RepartitionExec: partitioning=Hash([d_dkey@1], 3), input_partitions=3 +07)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +08)----------RepartitionExec: partitioning=Hash([f_dkey@1], 3), input_partitions=3 +09)------------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C prod 2017.6 + +########## +# TEST 12: Partitioned Join with Matching Partition Counts - With Optimization +# Both tables have 3 partitions matching target_partitions=3 +# 
No RepartitionExec needed for join - partitions already satisfy the requirement +# Dynamic filter pushdown is disabled in this mode because preserve_file_partitions +# reports Hash partitioning for Hive-style file groups, which are not hash-routed. +########## + +statement ok +set datafusion.optimizer.preserve_file_partitions = 1; + +query TT +EXPLAIN SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +logical_plan +01)Aggregate: groupBy=[[f.f_dkey, d.env]], aggr=[[sum(f.value)]] +02)--Projection: f.value, f.f_dkey, d.env +03)----Inner Join: f.f_dkey = d.d_dkey +04)------SubqueryAlias: f +05)--------TableScan: fact_table projection=[value, f_dkey] +06)------SubqueryAlias: d +07)--------TableScan: dimension_table_partitioned projection=[env, d_dkey] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[f_dkey@0 as f_dkey, env@1 as env], aggr=[sum(f.value)] +02)--RepartitionExec: partitioning=Hash([f_dkey@0, env@1], 3), input_partitions=3 +03)----AggregateExec: mode=Partial, gby=[f_dkey@1 as f_dkey, env@2 as env], aggr=[sum(f.value)] +04)------ProjectionExec: expr=[value@1 as value, f_dkey@2 as f_dkey, env@0 as env] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d_dkey@1, f_dkey@1)], projection=[env@0, value@2, f_dkey@3] +06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/dimension_partitioned/d_dkey=C/data.parquet]]}, projection=[env, d_dkey], file_type=parquet +07)----------DataSourceExec: file_groups={3 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=A/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=B/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/preserve_file_partitioning/fact/f_dkey=C/data.parquet]]}, projection=[value, f_dkey], file_type=parquet + +query TTR rowsort +SELECT f.f_dkey, d.env, sum(f.value) +FROM fact_table f +INNER JOIN dimension_table_partitioned d ON f.f_dkey = d.d_dkey +GROUP BY f.f_dkey, d.env; +---- +A dev 772.4 +B prod 614.4 +C prod 2017.6 + ########## # CLEANUP ########## @@ -592,5 +717,8 @@ DROP TABLE fact_table_ordered; statement ok DROP TABLE dimension_table; +statement ok +DROP TABLE dimension_table_partitioned; + statement ok DROP TABLE high_cardinality_table; diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt index 5a4411233424a..e18114bc51ca8 100644 --- a/datafusion/sqllogictest/test_files/projection.slt +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -167,12 +167,12 @@ set datafusion.explain.logical_plan_only = false # project cast dictionary query T -SELECT - CASE +SELECT + CASE WHEN cpu_load_short.host IS NULL THEN '' ELSE cpu_load_short.host END AS host -FROM +FROM cpu_load_short; ---- host1 @@ -275,7 +275,6 @@ logical_plan 02)--Filter: t1.a > Int64(1) 03)----TableScan: t1 projection=[a], partial_filters=[t1.a > Int64(1)] physical_plan -01)ProjectionExec: expr=[] -02)--FilterExec: a@0 > 1 -03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection/17513.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 > 1, pruning_predicate=a_null_count@1 != row_count@2 AND a_max@0 > 1, required_guarantees=[] +01)FilterExec: a@0 > 1, projection=[] +02)--RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection/17513.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 > 1, pruning_predicate=a_null_count@1 != row_count@2 AND a_max@0 > 1, required_guarantees=[] diff --git a/datafusion/sqllogictest/test_files/projection_pushdown.slt b/datafusion/sqllogictest/test_files/projection_pushdown.slt new file mode 100644 index 0000000000000..c25b80a0d7f20 --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection_pushdown.slt @@ -0,0 +1,1951 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +# Tests for projection pushdown behavior with get_field expressions +# +# This file tests the ExtractTrivialProjections optimizer rule and +# physical projection pushdown for: +# - get_field expressions (struct field access like s['foo']) +# - Pushdown through Filter, Sort, and TopK operators +# - Multi-partition scenarios with SortPreservingMergeExec +########## + +##################### +# Section 1: Setup - Single Partition Tests +##################### + +# Set target_partitions = 1 for deterministic plan output +statement ok +SET datafusion.execution.target_partitions = 1; + +# Create parquet file with struct column containing value and label fields +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {value: 100, label: 'alpha'}), + (2, {value: 200, label: 'beta'}), + (3, {value: 150, label: 'gamma'}), + (4, {value: 300, label: 'delta'}), + (5, {value: 250, label: 'epsilon'}) +) TO 'test_files/scratch/projection_pushdown/simple.parquet' +STORED AS PARQUET; + +# Create table for simple struct tests +statement ok +CREATE EXTERNAL TABLE simple_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/simple.parquet'; + +# Create parquet file with nested struct column +statement ok +COPY ( + SELECT + column1 as id, + column2 as nested + FROM VALUES + (1, {outer: {inner: 10, name: 'one'}, extra: 'x'}), + (2, {outer: {inner: 20, name: 'two'}, extra: 'y'}), + (3, {outer: {inner: 30, name: 'three'}, extra: 'z'}) +) TO 'test_files/scratch/projection_pushdown/nested.parquet' +STORED AS PARQUET; + +# Create table for nested struct tests +statement ok +CREATE EXTERNAL TABLE nested_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/nested.parquet'; + +# Create parquet file with nullable struct column +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {value: 100, label: 'alpha'}), + (2, NULL), + (3, {value: 150, label: 'gamma'}), + (4, NULL), + (5, {value: 
250, label: 'epsilon'}) +) TO 'test_files/scratch/projection_pushdown/nullable.parquet' +STORED AS PARQUET; + +# Create table for nullable struct tests +statement ok +CREATE EXTERNAL TABLE nullable_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/nullable.parquet'; + + +##################### +# Section 2: Basic get_field Pushdown (Projection above scan) +##################### + +### +# Test 2.1: Simple s['value'] +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +query TT +EXPLAIN SELECT s['label'] FROM simple_struct; +---- +logical_plan +01)Projection: get_field(simple_struct.s, Utf8("label")) +02)--TableScan: simple_struct projection=[s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as simple_struct.s[label]], file_type=parquet + +# Verify correctness +query T +SELECT s['label'] FROM simple_struct ORDER BY s['label']; +---- +alpha +beta +delta +epsilon +gamma + +### +# Test 2.2: Multiple get_field expressions +### + +query TT +EXPLAIN SELECT id, s['value'], s['label'] FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label]], file_type=parquet + +# Verify correctness +query IIT +SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id; +---- +1 100 alpha +2 200 beta +3 150 gamma +4 300 delta +5 250 epsilon + +### +# Test 2.3: Nested s['outer']['inner'] +### + +query TT +EXPLAIN SELECT id, nested['outer']['inner'] FROM nested_struct; +---- +logical_plan +01)Projection: nested_struct.id, get_field(nested_struct.nested, Utf8("outer"), Utf8("inner")) +02)--TableScan: nested_struct projection=[id, nested] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nested.parquet]]}, projection=[id, get_field(nested@1, outer, inner) as nested_struct.nested[outer][inner]], file_type=parquet + +# Verify correctness +query II +SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id; +---- +1 10 +2 20 +3 30 + +### +# Test 2.4: s['value'] + 1 +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +02)--TableScan: simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +1 101 +2 201 +3 151 +4 301 +5 251 + +### +# Test 2.5: s['label'] || '_suffix' +### + +query TT +EXPLAIN SELECT id, s['label'] || '_suffix' FROM simple_struct; +---- +logical_plan +01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("label")) || Utf8("_suffix") +02)--TableScan: 
simple_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) || _suffix as simple_struct.s[label] || Utf8("_suffix")], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id; +---- +1 alpha_suffix +2 beta_suffix +3 gamma_suffix +4 delta_suffix +5 epsilon_suffix + + +##################### +# Section 3: Projection Pushdown Through FilterExec +##################### + +### +# Test 3.1: Simple get_field through Filter +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 2 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 2 ORDER BY id; +---- +3 150 +4 300 +5 250 + +### +# Test 3.2: s['value'] + 1 through Filter +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 + Int64(1) AS simple_struct.s[value] + Int64(1) +02)--Filter: simple_struct.id > Int64(2) 
+03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +02)--FilterExec: id@1 > 2 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 2 ORDER BY id; +---- +3 151 +4 301 +5 251 + +### +# Test 3.3: Filter on get_field expression +### + +query TT +EXPLAIN SELECT id, s['label'] FROM simple_struct WHERE s['value'] > 150; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_2 AS simple_struct.s[label] +02)--Filter: __datafusion_extracted_1 > Int64(150) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2 +04)------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] +physical_plan +01)ProjectionExec: expr=[id@0 as id, __datafusion_extracted_2@1 as simple_struct.s[label]] +02)--FilterExec: __datafusion_extracted_1@0 > 150, projection=[id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id, get_field(s@1, label) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] FROM 
simple_struct WHERE s['value'] > 150 ORDER BY id; +---- +2 beta +4 delta +5 epsilon + + +##################### +# Section 4: Projection Pushdown Through SortExec (no LIMIT) +##################### + +### +# Test 4.1: Simple get_field through Sort +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +### +# Test 4.2: s['value'] + 1 through Sort - split projection +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id; +---- +1 101 +2 201 +3 151 +4 301 +5 251 + +### +# Test 4.3: Sort by get_field expression +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY s['value']; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST 
+02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY s['value']; +---- +1 100 +3 150 +2 200 +5 250 +4 300 + +### +# Test 4.4: Projection with duplicate column through Sort +# The projection expands the number of columns from 3 to 4 by introducing `col_b_dup` +### + +statement ok +COPY ( + SELECT + column1 as col_a, + column2 as col_b, + column3 as col_c + FROM VALUES + (1, 2, 3), + (4, 5, 6), + (7, 8, 9) +) TO 'test_files/scratch/projection_pushdown/three_cols.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE three_cols STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/three_cols.parquet'; + +query TT +EXPLAIN SELECT col_a, col_b, col_c, col_b as col_b_dup FROM three_cols ORDER BY col_a; +---- +logical_plan +01)Sort: three_cols.col_a ASC NULLS LAST +02)--Projection: three_cols.col_a, three_cols.col_b, three_cols.col_c, three_cols.col_b AS col_b_dup +03)----TableScan: three_cols projection=[col_a, col_b, col_c] +physical_plan +01)SortExec: expr=[col_a@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/three_cols.parquet]]}, projection=[col_a, col_b, col_c, col_b@1 as col_b_dup], file_type=parquet + +# Verify correctness +query IIII +SELECT col_a, col_b, col_c, col_b as col_b_dup FROM three_cols ORDER BY col_a DESC; +---- +7 8 9 8 +4 5 6 5 +1 2 3 2 + +statement ok +DROP TABLE three_cols; + + +##################### +# Section 
5: Projection Pushdown Through TopK (ORDER BY + LIMIT) +##################### + +### +# Test 5.1: Simple get_field through TopK +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 +2 200 +3 150 + +### +# Test 5.2: s['value'] + 1 through TopK +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 201 +3 151 + +### +# Test 5.3: Multiple get_field through TopK +### + +query TT +EXPLAIN SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 
+02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIT +SELECT id, s['value'], s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 alpha +2 200 beta +3 150 gamma + +### +# Test 5.4: Nested get_field through TopK +### + +query TT +EXPLAIN SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: nested_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: nested_struct.id, get_field(nested_struct.nested, Utf8("outer"), Utf8("inner")) +03)----TableScan: nested_struct projection=[id, nested] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nested.parquet]]}, projection=[id, get_field(nested@1, outer, inner) as nested_struct.nested[outer][inner]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, nested['outer']['inner'] FROM nested_struct ORDER BY id LIMIT 2; +---- +1 10 +2 20 + +### +# Test 5.5: String concat through TopK +### + +query TT +EXPLAIN SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("label")) || Utf8("_suffix") +03)----TableScan: simple_struct projection=[id, 
s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) || _suffix as simple_struct.s[label] || Utf8("_suffix")], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IT +SELECT id, s['label'] || '_suffix' FROM simple_struct ORDER BY id LIMIT 3; +---- +1 alpha_suffix +2 beta_suffix +3 gamma_suffix + + +##################### +# Section 6: Combined Operators (Filter + Sort/TopK) +##################### + +### +# Test 6.1: Filter + Sort + get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value']; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST +02)--Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value']; +---- +3 150 +2 200 +5 250 +4 300 + +### +# Test 6.2: Filter + TopK + get_field +### + 
+query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value'] LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.s[value] ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: TopK(fetch=2), expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY s['value'] LIMIT 2; +---- +3 150 +2 200 + +### +# Test 6.3: Filter + TopK + get_field with arithmetic +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, __datafusion_extracted_1 + Int64(1) AS simple_struct.s[value] + Int64(1) +03)----Filter: simple_struct.id > Int64(1) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@1 as id, 
__datafusion_extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +2 201 +3 151 + + +##################### +# Section 7: Multi-Partition Tests +##################### + +# Set target_partitions = 4 for parallel execution +statement ok +SET datafusion.execution.target_partitions = 4; + +# Create 5 parquet files (more than partitions) for parallel tests +statement ok +COPY (SELECT 1 as id, {value: 100, label: 'alpha'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part1.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 2 as id, {value: 200, label: 'beta'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part2.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 3 as id, {value: 150, label: 'gamma'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part3.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 4 as id, {value: 300, label: 'delta'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part4.parquet' +STORED AS PARQUET; + +statement ok +COPY (SELECT 5 as id, {value: 250, label: 'epsilon'} as s) +TO 'test_files/scratch/projection_pushdown/multi/part5.parquet' +STORED AS PARQUET; + +# Create table from multiple parquet files +statement ok +CREATE EXTERNAL TABLE multi_struct STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/multi/'; + +### +# Test 7.1: Multi-partition Sort with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct ORDER BY id; +---- +logical_plan 
+01)Sort: multi_struct.id ASC NULLS LAST +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) as multi_struct.s[value]], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct ORDER BY id; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +### +# Test 7.2: Multi-partition TopK with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: multi_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, 
WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) as multi_struct.s[value]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct ORDER BY id LIMIT 3; +---- +1 100 +2 200 +3 150 + +### +# Test 7.3: Multi-partition TopK with arithmetic (non-trivial stays above merge) +### + +query TT +EXPLAIN SELECT id, s['value'] + 1 FROM multi_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: multi_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: multi_struct.id, get_field(multi_struct.s, Utf8("value")) + Int64(1) +03)----TableScan: multi_struct projection=[id, s] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) + 1 as multi_struct.s[value] + Int64(1)], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + 1 FROM multi_struct ORDER BY id LIMIT 3; +---- +1 101 +2 201 +3 151 + +### +# Test 7.4: Multi-partition Filter with get_field +### + +query TT +EXPLAIN SELECT id, s['value'] FROM multi_struct WHERE id > 2 ORDER BY id; +---- +logical_plan +01)Sort: 
multi_struct.id ASC NULLS LAST +02)--Projection: multi_struct.id, __datafusion_extracted_1 AS multi_struct.s[value] +03)----Filter: multi_struct.id > Int64(2) +04)------Projection: get_field(multi_struct.s, Utf8("value")) AS __datafusion_extracted_1, multi_struct.id +05)--------TableScan: multi_struct projection=[id, s], partial_filters=[multi_struct.id > Int64(2)] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as multi_struct.s[value]] +04)------FilterExec: id@1 > 2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM multi_struct WHERE id > 2 ORDER BY id; +---- +3 150 +4 300 +5 250 + +### +# Test 7.5: Aggregation with get_field (CoalescePartitions) +### + +query TT +EXPLAIN SELECT s['label'], SUM(s['value']) FROM multi_struct GROUP BY s['label']; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS multi_struct.s[label], sum(__datafusion_extracted_2) AS sum(multi_struct.s[value]) +02)--Aggregate: groupBy=[[__datafusion_extracted_1]], 
aggr=[[sum(__datafusion_extracted_2)]] +03)----Projection: get_field(multi_struct.s, Utf8("label")) AS __datafusion_extracted_1, get_field(multi_struct.s, Utf8("value")) AS __datafusion_extracted_2 +04)------TableScan: multi_struct projection=[s] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as multi_struct.s[label], sum(__datafusion_extracted_2)@1 as sum(multi_struct.s[value])] +02)--AggregateExec: mode=FinalPartitioned, gby=[__datafusion_extracted_1@0 as __datafusion_extracted_1], aggr=[sum(__datafusion_extracted_2)] +03)----RepartitionExec: partitioning=Hash([__datafusion_extracted_1@0], 4), input_partitions=3 +04)------AggregateExec: mode=Partial, gby=[__datafusion_extracted_1@0 as __datafusion_extracted_1], aggr=[sum(__datafusion_extracted_2)] +05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query TI +SELECT s['label'], SUM(s['value']) FROM multi_struct GROUP BY s['label'] ORDER BY s['label']; +---- +alpha 100 +beta 200 +delta 300 +epsilon 250 +gamma 150 + + +##################### +# Section 8: Edge Cases +##################### + +# Reset to single partition for edge case tests +statement ok +SET datafusion.execution.target_partitions = 1; + +### +# Test 8.1: get_field on nullable struct column +### + +query TT +EXPLAIN SELECT id, s['value'] FROM nullable_struct; +---- +logical_plan 
+01)Projection: nullable_struct.id, get_field(nullable_struct.s, Utf8("value")) +02)--TableScan: nullable_struct projection=[id, s] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, get_field(s@1, value) as nullable_struct.s[value]], file_type=parquet + +# Verify correctness (NULL struct returns NULL field) +query II +SELECT id, s['value'] FROM nullable_struct ORDER BY id; +---- +1 100 +2 NULL +3 150 +4 NULL +5 250 + +### +# Test 8.2: get_field returning NULL values +### + +query TT +EXPLAIN SELECT id, s['label'] FROM nullable_struct WHERE s['value'] IS NOT NULL; +---- +logical_plan +01)Projection: nullable_struct.id, __datafusion_extracted_2 AS nullable_struct.s[label] +02)--Filter: __datafusion_extracted_1 IS NOT NULL +03)----Projection: get_field(nullable_struct.s, Utf8("value")) AS __datafusion_extracted_1, nullable_struct.id, get_field(nullable_struct.s, Utf8("label")) AS __datafusion_extracted_2 +04)------TableScan: nullable_struct projection=[id, s], partial_filters=[get_field(nullable_struct.s, Utf8("value")) IS NOT NULL] +physical_plan +01)ProjectionExec: expr=[id@0 as id, __datafusion_extracted_2@1 as nullable_struct.s[label]] +02)--FilterExec: __datafusion_extracted_1@0 IS NOT NULL, projection=[id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id, get_field(s@1, label) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness +query IT +SELECT id, s['label'] FROM nullable_struct WHERE s['value'] IS NOT NULL ORDER BY id; +---- +1 alpha +3 gamma +5 epsilon + +### +# Test 8.3: Mixed trivial and non-trivial in same projection +### + +query TT +EXPLAIN SELECT id, s['value'], s['value'] + 10, s['label'] FROM simple_struct ORDER BY id 
LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("value")) + Int64(10), get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, value) + 10 as simple_struct.s[value] + Int64(10), get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIIT +SELECT id, s['value'], s['value'] + 10, s['label'] FROM simple_struct ORDER BY id LIMIT 3; +---- +1 100 110 alpha +2 200 210 beta +3 150 160 gamma + +### +# Test 8.4: Literal projection through TopK +### + +query TT +EXPLAIN SELECT id, 42 as constant FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, Int64(42) AS constant +03)----TableScan: simple_struct projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, 42 as constant], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, 42 as constant FROM simple_struct ORDER BY id LIMIT 3; +---- +1 42 +2 42 +3 42 + +### +# Test 8.5: Simple column through TopK (baseline comparison) +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--TableScan: simple_struct 
projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY id LIMIT 3; +---- +1 +2 +3 + + +##################### +# Section 9: Coverage Tests - Edge Cases for Uncovered Code Paths +##################### + +### +# Test 9.1: TopK with computed projection +### + +query TT +EXPLAIN SELECT id, id + 100 as computed FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, simple_struct.id + Int64(100) AS computed +03)----TableScan: simple_struct projection=[id] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, id@0 + 100 as computed], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, id + 100 as computed FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 102 +3 103 + +### +# Test 9.2: Duplicate get_field expressions (same expression referenced twice) +# Common subexpression elimination happens in the logical plan, and the physical +# plan extracts the shared get_field for efficient computation +### + +query TT +EXPLAIN SELECT (id + s['value']) * (id + s['value']) as id_and_value FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __common_expr_1 * __common_expr_1 AS id_and_value +02)--Projection: simple_struct.id + __datafusion_extracted_2 AS __common_expr_1 +03)----Filter: simple_struct.id > Int64(2) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS 
__datafusion_extracted_2, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 * __common_expr_1@0 as id_and_value] +02)--ProjectionExec: expr=[id@1 + __datafusion_extracted_2@0 as __common_expr_1] +03)----FilterExec: id@1 > 2 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + + +query TT +EXPLAIN SELECT s['value'] + s['value'] as doubled FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 + __datafusion_extracted_1 AS doubled +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 + __datafusion_extracted_1@0 as doubled] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT s['value'] + s['value'] as doubled FROM simple_struct WHERE id > 2 ORDER BY doubled; +---- +300 +500 +600 + +### +# Test 9.3: Projection with only get_field expressions through Filter +### + +query TT +EXPLAIN SELECT s['value'], s['label'] FROM simple_struct WHERE id > 2; +---- +logical_plan 
+01)Projection: __datafusion_extracted_1 AS simple_struct.s[value], __datafusion_extracted_2 AS simple_struct.s[label] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value], __datafusion_extracted_2@1 as simple_struct.s[label]] +02)--FilterExec: id@2 > 2, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query IT +SELECT s['value'], s['label'] FROM simple_struct WHERE id > 2 ORDER BY s['value']; +---- +150 gamma +250 epsilon +300 delta + +### +# Test 9.4: Mixed column reference with get_field in expression through TopK +# Tests column remapping in finalize_outer_exprs when outer expr references both extracted and original columns +### + +query TT +EXPLAIN SELECT id, s['value'] + id as combined FROM simple_struct ORDER BY id LIMIT 3; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=3 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + simple_struct.id AS combined +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + id@0 as combined], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT id, s['value'] + id as combined FROM simple_struct ORDER BY id LIMIT 3; +---- +1 101 +2 202 +3 153 + +### +# Test 9.5: Multiple get_field from same struct in expression through Filter +# Tests extraction when base struct is shared across multiple get_field calls +### + +query TT +EXPLAIN SELECT s['value'] * 2 + length(s['label']) as score FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: __datafusion_extracted_1 * Int64(2) + CAST(character_length(__datafusion_extracted_2) AS Int64) AS score +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 * 2 + CAST(character_length(__datafusion_extracted_2@1) AS Int64) as score] +02)--FilterExec: id@2 > 1, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query I +SELECT s['value'] * 2 + length(s['label']) as score FROM simple_struct WHERE id > 1 ORDER BY score; +---- +305 +404 +507 +605 + + +##################### +# Section 10: Literal with get_field Expressions +##################### + +### +# Test 
10.1: Literal constant + get_field in same projection +# Tests projection with both trivial (literal) and get_field expressions +### + +query TT +EXPLAIN SELECT id, 42 as answer, s['label'] FROM simple_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, Int64(42) AS answer, get_field(simple_struct.s, Utf8("label")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, 42 as answer, get_field(s@1, label) as simple_struct.s[label]], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query IIT +SELECT id, 42 as answer, s['label'] FROM simple_struct ORDER BY id LIMIT 2; +---- +1 42 alpha +2 42 beta + +### +# Test 10.2: Multiple non-trivial get_field expressions together +# Tests arithmetic on one field and string concat on another in same projection +### + +query TT +EXPLAIN SELECT id, s['value'] + 100, s['label'] || '_test' FROM simple_struct ORDER BY id LIMIT 2; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, fetch=2 +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) + Int64(100), get_field(simple_struct.s, Utf8("label")) || Utf8("_test") +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) + 100 as simple_struct.s[value] + Int64(100), get_field(s@1, label) || _test as simple_struct.s[label] || Utf8("_test")], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness 
+query IIT +SELECT id, s['value'] + 100, s['label'] || '_test' FROM simple_struct ORDER BY id LIMIT 2; +---- +1 200 alpha_test +2 300 beta_test + +##################### +# Section 11: FilterExec Projection Pushdown - Handling Predicate Column Requirements +##################### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 2; +---- +2 200 +3 150 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 1 AND (id < 4 OR id = 5); +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) AND (simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 
AND (id@1 < 4 OR id@1 = 5), projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND (id@0 < 4 OR id@0 = 5), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND (id_null_count@1 != row_count@2 AND id_min@3 < 4 OR id_null_count@1 != row_count@2 AND id_min@3 <= 5 AND 5 <= id_max@0), required_guarantees=[] + +# Verify correctness - should return rows where (id > 1) AND ((id < 4) OR (id = 5)) +# That's: id=2,3 (1 < id < 4) plus id=5 +query I +SELECT s['value'] FROM simple_struct WHERE id > 1 AND (id < 4 OR id = 5) ORDER BY s['value']; +---- +150 +200 +250 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 1 AND id < 5; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(1) AND simple_struct.id < Int64(5) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(5)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 1 AND id@1 < 5, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND id@0 < 5, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND id_null_count@1 != row_count@2 AND id_min@3 < 5, required_guarantees=[] + +# Verify correctness - should return rows where 1 < id < 5 (id=2,3,4) +query I +SELECT s['value'] FROM simple_struct WHERE id > 1 AND id < 5 ORDER BY s['value']; +---- +150 +200 +300 + +query TT
+EXPLAIN SELECT s['value'], s['label'], id FROM simple_struct WHERE id > 1; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value], __datafusion_extracted_2 AS simple_struct.s[label], simple_struct.id +02)--Filter: simple_struct.id > Int64(1) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value], __datafusion_extracted_2@1 as simple_struct.s[label], id@2 as id] +02)--FilterExec: id@2 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] + +# Verify correctness - note that id is now at index 2 in the augmented projection +query ITI +SELECT s['value'], s['label'], id FROM simple_struct WHERE id > 1 ORDER BY id LIMIT 3; +---- +200 beta 2 +150 gamma 3 +300 delta 4 + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE length(s['label']) > 4; +---- +logical_plan +01)Projection: __datafusion_extracted_2 AS simple_struct.s[value] +02)--Filter: character_length(__datafusion_extracted_1) > Int32(4) +03)----Projection: get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_2 +04)------TableScan: simple_struct projection=[s], partial_filters=[character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_2@0 as 
simple_struct.s[value]] +02)--FilterExec: character_length(__datafusion_extracted_1@0) > 4, projection=[__datafusion_extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_2], file_type=parquet + +# Verify correctness - filter on rows where label length > 4 +# From the data the label lengths are: alpha(5), beta(4), gamma(5), delta(5), epsilon(7) +# So: alpha, gamma, delta, epsilon (not beta, which has 4 characters) +query I +SELECT s['value'] FROM simple_struct WHERE length(s['label']) > 4 ORDER BY s['value']; +---- +100 +150 +250 +300 + +##################### +# Section 11a: ProjectionExec on top of a SortExec with missing Sort Columns +##################### + +### +# Test 11a.1: Sort by dropped column +# Selects only id, drops s entirely, but sorts by s['value'] +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value']; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: expr=[__datafusion_extracted_1@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value']; +---- +1 +3 +2 +5 +4 + +### +# Test 11a.2: Multiple sort columns with partial selection +# Selects only id and s['value'], but sorts by id and s['label'] +# One sort column 
(s['label']) is not selected but needed for ordering +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id, s['label']; +---- +logical_plan +01)Projection: simple_struct.id, simple_struct.s[value] +02)--Sort: simple_struct.id ASC NULLS LAST, __datafusion_extracted_1 ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id, simple_struct.s[value]@1 as simple_struct.s[value]] +02)--SortExec: expr=[id@0 ASC NULLS LAST, __datafusion_extracted_1@2 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query II +SELECT id, s['value'] FROM simple_struct ORDER BY id, s['label']; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + + +### +# Test 11a.3: TopK with dropped sort column +# Same as test 11a.1 but with LIMIT +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] LIMIT 2; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 ASC NULLS LAST, fetch=2 +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: TopK(fetch=2), expr=[__datafusion_extracted_1@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], 
file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value'] LIMIT 2; +---- +1 +3 + +### +# Test 11a.4: Sort by derived expression with dropped column +# Projects only id, sorts by s['value'] * 2 (derived expression) +# Sort column is computed but requires base columns not in projection +### + +query TT +EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] * 2; +---- +logical_plan +01)Projection: simple_struct.id +02)--Sort: __datafusion_extracted_1 * Int64(2) ASC NULLS LAST +03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1 +04)------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@0 as id] +02)--SortExec: expr=[__datafusion_extracted_1@1 * 2 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as __datafusion_extracted_1], file_type=parquet + +# Verify correctness +query I +SELECT id FROM simple_struct ORDER BY s['value'] * 2; +---- +1 +3 +2 +5 +4 + +### +# Test 11a.5: All sort columns selected +# All columns needed for sorting are included in projection +### + +query TT +EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id, s['value']; +---- +logical_plan +01)Sort: simple_struct.id ASC NULLS LAST, simple_struct.s[value] ASC NULLS LAST +02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) +03)----TableScan: simple_struct projection=[id, s] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST, simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet + +# Verify correctness +query II 
+SELECT id, s['value'] FROM simple_struct ORDER BY id, s['value']; +---- +1 100 +2 200 +3 150 +4 300 +5 250 + +##################### +# Section 12: Join Tests - get_field Extraction from Join Nodes +##################### + +# Create a second table for join tests +statement ok +COPY ( + SELECT + column1 as id, + column2 as s + FROM VALUES + (1, {role: 'admin', level: 10}), + (2, {role: 'user', level: 5}), + (3, {role: 'guest', level: 1}), + (4, {role: 'admin', level: 8}), + (5, {role: 'user', level: 3}) +) TO 'test_files/scratch/projection_pushdown/join_right.parquet' +STORED AS PARQUET; + +statement ok +CREATE EXTERNAL TABLE join_right STORED AS PARQUET +LOCATION 'test_files/scratch/projection_pushdown/join_right.parquet'; + +### +# Test 12.1: Join with get_field in equijoin condition +# Tests extraction from join ON clause - get_field on each side routed appropriately +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.s['value'] = join_right.s['level'] * 10; +---- +logical_plan +01)Projection: simple_struct.id, join_right.id +02)--Inner Join: __datafusion_extracted_1 = __datafusion_extracted_2 * Int64(10) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(__datafusion_extracted_1@0, __datafusion_extracted_2 * Int64(10)@2)], projection=[id@1, id@3] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, id, get_field(s@1, level) * 10 as __datafusion_extracted_2 * Int64(10)], file_type=parquet + +# Verify correctness - value = level * 10 +# simple_struct: (1,100), (2,200), (3,150), (4,300), (5,250) +# join_right: (1,10), (2,5), (3,1), (4,8), (5,3) +# Matches: simple_struct.value=100 matches join_right.level*10=100 (level=10, id=1) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.s['value'] = join_right.s['level'] * 10 +ORDER BY simple_struct.id; +---- +1 1 + +### +# Test 12.2: Join with get_field in non-equi filter +# Tests extraction from join filter expression - left side only +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 150; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(150) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] +06)--TableScan: join_right projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--FilterExec: __datafusion_extracted_1@0 > 150, projection=[id@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +04)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[id], 
file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - id matches and value > 150 +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 150 +ORDER BY simple_struct.id; +---- +2 2 +4 4 +5 5 + +### +# Test 12.3: Join with get_field from both sides in filter +# Tests extraction routing to both left and right inputs +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 100 AND join_right.s['level'] > 3; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(100) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(100)] +06)--Projection: join_right.id +07)----Filter: __datafusion_extracted_2 > Int64(3) +08)------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +09)--------TableScan: join_right projection=[id, s], partial_filters=[get_field(join_right.s, Utf8("level")) > Int64(3)] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--FilterExec: __datafusion_extracted_1@0 > 100, projection=[id@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +04)--FilterExec: __datafusion_extracted_2@0 > 3, projection=[id@1] +05)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, 
level) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - id matches, value > 100, and level > 3 +# Matching ids where value > 100: 2(200), 3(150), 4(300), 5(250) +# Of those, level > 3: 2(5), 4(8), 5(3) -> only 2 and 4 +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > 100 AND join_right.s['level'] > 3 +ORDER BY simple_struct.id; +---- +2 2 +4 4 + +### +# Test 12.4: Join with get_field in SELECT projection +# Tests that get_field in output columns pushes down through the join +### + +query TT +EXPLAIN SELECT simple_struct.id, simple_struct.s['label'], join_right.s['role'] +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id; +---- +logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_1 AS simple_struct.s[label], __datafusion_extracted_2 AS join_right.s[role] +02)--Inner Join: simple_struct.id = join_right.id +03)----Projection: get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("role")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_1@0 as simple_struct.s[label], __datafusion_extracted_2@2 as join_right.s[role]] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@1, id@1)], projection=[__datafusion_extracted_1@0, id@1, __datafusion_extracted_2@2] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as __datafusion_extracted_1, id], file_type=parquet +04)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, role) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query ITT +SELECT simple_struct.id, simple_struct.s['label'], join_right.s['role'] +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +ORDER BY simple_struct.id; +---- +1 alpha admin +2 beta user +3 gamma guest +4 delta admin +5 epsilon user + +### +# Test 12.5: Join without get_field (baseline - no extraction needed) +# Verifies no unnecessary projections are added when there's nothing to extract +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id; +---- +logical_plan +01)Inner Join: simple_struct.id = join_right.id +02)--TableScan: simple_struct projection=[id] +03)--TableScan: join_right projection=[id] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +ORDER BY simple_struct.id; +---- +1 1 +2 2 +3 3 +4 4 +5 5 + +### +# Test 12.6: Left Join with get_field extraction +# Tests extraction works correctly with outer joins +### + +query TT +EXPLAIN SELECT simple_struct.id, simple_struct.s['value'], join_right.s['level'] +FROM simple_struct +LEFT JOIN join_right ON simple_struct.id = join_right.id AND join_right.s['level'] > 5; +---- 
+logical_plan +01)Projection: simple_struct.id, __datafusion_extracted_2 AS simple_struct.s[value], __datafusion_extracted_3 AS join_right.s[level] +02)--Left Join: simple_struct.id = join_right.id +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_2, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: join_right.id, __datafusion_extracted_3 +06)------Filter: __datafusion_extracted_1 > Int64(5) +07)--------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_1, join_right.id, get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_3 +08)----------TableScan: join_right projection=[id, s], partial_filters=[get_field(join_right.s, Utf8("level")) > Int64(5)] +physical_plan +01)ProjectionExec: expr=[id@1 as id, __datafusion_extracted_2@0 as simple_struct.s[value], __datafusion_extracted_3@2 as join_right.s[level]] +02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@1, id@0)], projection=[__datafusion_extracted_2@0, id@1, __datafusion_extracted_3@3] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_2, id], file_type=parquet +04)----FilterExec: __datafusion_extracted_1@0 > 5, projection=[id@1, __datafusion_extracted_3@2] +05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_1, id, get_field(s@1, level) as __datafusion_extracted_3], file_type=parquet + +# Verify correctness - left join with level > 5 condition +# Only join_right rows with level > 5 are matched: id=1 (level=10), id=4 (level=8) +query III +SELECT simple_struct.id, simple_struct.s['value'], join_right.s['level'] +FROM simple_struct +LEFT JOIN join_right ON simple_struct.id = 
join_right.id AND join_right.s['level'] > 5 +ORDER BY simple_struct.id; +---- +1 100 10 +2 200 NULL +3 150 NULL +4 300 8 +5 250 NULL + +##################### +# Section 13: RepartitionExec tests +##################### + +# Set target partitions to 32 -> this forces a RepartitionExec +statement ok +SET datafusion.execution.target_partitions = 32; + +query TT +EXPLAIN SELECT s['value'] FROM simple_struct WHERE id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS simple_struct.s[value] +02)--Filter: simple_struct.id > Int64(2) +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as simple_struct.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +##################### +# Section 14: SubqueryAlias tests +##################### + +# Reset target partitions +statement ok +SET datafusion.execution.target_partitions = 1; + +# get_field pushdown through subquery alias with filter +query TT +EXPLAIN SELECT t.s['value'] FROM (SELECT * FROM simple_struct) t WHERE t.id > 2; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS t.s[value] +02)--SubqueryAlias: t +03)----Projection: __datafusion_extracted_1 +04)------Filter: simple_struct.id > Int64(2) +05)--------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +06)----------TableScan: 
simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT t.s['value'] FROM (SELECT * FROM simple_struct) t WHERE t.id > 2 ORDER BY t.id; +---- +150 +300 +250 + +# Multiple get_field through subquery alias with sort +query TT +EXPLAIN SELECT t.s['value'], t.s['label'] FROM (SELECT * FROM simple_struct) t ORDER BY t.s['value']; +---- +logical_plan +01)Sort: t.s[value] ASC NULLS LAST +02)--Projection: __datafusion_extracted_1 AS t.s[value], __datafusion_extracted_2 AS t.s[label] +03)----SubqueryAlias: t +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2 +05)--------TableScan: simple_struct projection=[s] +physical_plan +01)SortExec: expr=[t.s[value]@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as t.s[value], get_field(s@1, label) as t.s[label]], file_type=parquet + +# Verify correctness +query IT +SELECT t.s['value'], t.s['label'] FROM (SELECT * FROM simple_struct) t ORDER BY t.s['value']; +---- +100 alpha +150 gamma +200 beta +250 epsilon +300 delta + +# Nested subquery aliases +query TT +EXPLAIN SELECT u.s['value'] FROM (SELECT * FROM (SELECT * FROM simple_struct) t) u WHERE u.id > 2; +---- +logical_plan +01)Projection: 
__datafusion_extracted_1 AS u.s[value] +02)--SubqueryAlias: u +03)----SubqueryAlias: t +04)------Projection: __datafusion_extracted_1 +05)--------Filter: simple_struct.id > Int64(2) +06)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +07)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as u.s[value]] +02)--FilterExec: id@1 > 2, projection=[__datafusion_extracted_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] + +# Verify correctness +query I +SELECT u.s['value'] FROM (SELECT * FROM (SELECT * FROM simple_struct) t) u WHERE u.id > 2 ORDER BY u.id; +---- +150 +300 +250 + +# get_field in filter through subquery alias +query TT +EXPLAIN SELECT t.id FROM (SELECT * FROM simple_struct) t WHERE t.s['value'] > 200; +---- +logical_plan +01)SubqueryAlias: t +02)--Projection: simple_struct.id +03)----Filter: __datafusion_extracted_1 > Int64(200) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(200)] +physical_plan +01)FilterExec: __datafusion_extracted_1@0 > 200, projection=[id@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet + +# Verify correctness +query I +SELECT t.id FROM (SELECT * FROM simple_struct) t WHERE t.s['value'] > 200 ORDER BY 
t.id; +---- +4 +5 + +##################### +# Section 15: UNION ALL tests +##################### + +# get_field on UNION ALL result +query TT +EXPLAIN SELECT s['value'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t; +---- +logical_plan +01)Projection: __datafusion_extracted_1 AS t.s[value] +02)--SubqueryAlias: t +03)----Union +04)------Projection: __datafusion_extracted_1 +05)--------Filter: simple_struct.id <= Int64(3) +06)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +07)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id <= Int64(3)] +08)------Projection: __datafusion_extracted_1 +09)--------Filter: simple_struct.id > Int64(3) +10)----------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +11)------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(3)] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value]] +02)--UnionExec +03)----FilterExec: id@1 <= 3, projection=[__datafusion_extracted_1@0] +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 <= 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 <= 3, required_guarantees=[] +05)----FilterExec: id@1 > 3, projection=[__datafusion_extracted_1@0] +06)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet, predicate=id@0 > 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 3, required_guarantees=[] + +# Verify correctness +query I +SELECT 
s['value'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +100 +150 +200 +250 +300 + +# Multiple get_field on UNION ALL with ORDER BY +query TT +EXPLAIN SELECT s['value'], s['label'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +logical_plan +01)Sort: t.s[value] ASC NULLS LAST +02)--Projection: __datafusion_extracted_1 AS t.s[value], __datafusion_extracted_2 AS t.s[label] +03)----SubqueryAlias: t +04)------Union +05)--------Projection: __datafusion_extracted_1, __datafusion_extracted_2 +06)----------Filter: simple_struct.id <= Int64(3) +07)------------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +08)--------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id <= Int64(3)] +09)--------Projection: __datafusion_extracted_1, __datafusion_extracted_2 +10)----------Filter: simple_struct.id > Int64(3) +11)------------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("label")) AS __datafusion_extracted_2, simple_struct.id +12)--------------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(3)] +physical_plan +01)SortPreservingMergeExec: [t.s[value]@0 ASC NULLS LAST] +02)--SortExec: expr=[t.s[value]@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[__datafusion_extracted_1@0 as t.s[value], __datafusion_extracted_2@1 as t.s[label]] +04)------UnionExec +05)--------FilterExec: id@2 <= 3, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, 
projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 <= 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_min@0 <= 3, required_guarantees=[] +07)--------FilterExec: id@2 > 3, projection=[__datafusion_extracted_1@0, __datafusion_extracted_2@1] +08)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, label) as __datafusion_extracted_2, id], file_type=parquet, predicate=id@0 > 3, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 3, required_guarantees=[] + +# Verify correctness +query IT +SELECT s['value'], s['label'] FROM ( + SELECT s FROM simple_struct WHERE id <= 3 + UNION ALL + SELECT s FROM simple_struct WHERE id > 3 +) t ORDER BY s['value']; +---- +100 alpha +150 gamma +200 beta +250 epsilon +300 delta + +##################### +# Section 16: Aggregate / Join edge-case tests +# Translated from unit tests in extract_leaf_expressions.rs +##################### + +### +# Test 16.1: Projection with get_field above Aggregate +# Aggregate blocks pushdown, so the get_field stays in the top projection. 
+# (mirrors test_projection_with_leaf_expr_above_aggregate) +### + +query TT +EXPLAIN SELECT s['label'] IS NOT NULL AS has_label, COUNT(1) +FROM simple_struct GROUP BY s; +---- +logical_plan +01)Projection: get_field(simple_struct.s, Utf8("label")) IS NOT NULL AS has_label, count(Int64(1)) +02)--Aggregate: groupBy=[[simple_struct.s]], aggr=[[count(Int64(1))]] +03)----TableScan: simple_struct projection=[s] +physical_plan +01)ProjectionExec: expr=[get_field(s@0, label) IS NOT NULL as has_label, count(Int64(1))@1 as count(Int64(1))] +02)--AggregateExec: mode=Single, gby=[s@0 as s], aggr=[count(Int64(1))] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[s], file_type=parquet + +# Verify correctness - all labels are non-null +query BI +SELECT s['label'] IS NOT NULL AS has_label, COUNT(1) +FROM simple_struct GROUP BY s ORDER BY COUNT(1); +---- +true 1 +true 1 +true 1 +true 1 +true 1 + +### +# Test 16.2: Join with get_field filter on qualified right side +# The get_field on join_right.s['role'] must be routed to the right input only. 
+# (mirrors test_extract_from_join_qualified_right_side) +### + +query TT +EXPLAIN +SELECT s.s['value'], j.s['role'] +FROM join_right j +INNER JOIN simple_struct s ON s.id = j.id +WHERE s.s['value'] > j.s['level']; +---- +logical_plan +01)Projection: __datafusion_extracted_3 AS s.s[value], __datafusion_extracted_4 AS j.s[role] +02)--Inner Join: j.id = s.id Filter: __datafusion_extracted_1 > __datafusion_extracted_2 +03)----SubqueryAlias: j +04)------Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, get_field(join_right.s, Utf8("role")) AS __datafusion_extracted_4, join_right.id +05)--------TableScan: join_right projection=[id, s] +06)----SubqueryAlias: s +07)------Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_3, simple_struct.id +08)--------TableScan: simple_struct projection=[id, s] +physical_plan +01)ProjectionExec: expr=[__datafusion_extracted_3@1 as s.s[value], __datafusion_extracted_4@0 as j.s[role]] +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@2, id@2)], filter=__datafusion_extracted_1@1 > __datafusion_extracted_2@0, projection=[__datafusion_extracted_4@1, __datafusion_extracted_3@4] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, get_field(s@1, role) as __datafusion_extracted_4, id], file_type=parquet +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, get_field(s@1, value) as __datafusion_extracted_3, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - only admin roles match (ids 1 and 4) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct 
+INNER JOIN join_right + ON simple_struct.id = join_right.id + AND join_right.s['role'] = 'admin' +ORDER BY simple_struct.id; +---- +1 1 +4 4 + +### +# Test 16.3: Join with cross-input get_field comparison in WHERE +# get_field from each side is extracted and routed to its respective input independently. +# (mirrors test_extract_from_join_cross_input_expression) +### + +query TT +EXPLAIN SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > join_right.s['level']; +---- +logical_plan +01)Projection: simple_struct.id, join_right.id +02)--Inner Join: simple_struct.id = join_right.id Filter: __datafusion_extracted_1 > __datafusion_extracted_2 +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __datafusion_extracted_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] +05)----Projection: get_field(join_right.s, Utf8("level")) AS __datafusion_extracted_2, join_right.id +06)------TableScan: join_right projection=[id, s] +physical_plan +01)HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@1, id@1)], filter=__datafusion_extracted_1@0 > __datafusion_extracted_2@1, projection=[id@1, id@3] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __datafusion_extracted_1, id], file_type=parquet +03)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/join_right.parquet]]}, projection=[get_field(s@1, level) as __datafusion_extracted_2, id], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify correctness - all rows match since value >> level for all ids +# simple_struct: (1,100), (2,200), (3,150), (4,300), (5,250) +# join_right: (1,10), (2,5), (3,1), (4,8), (5,3) +query II +SELECT simple_struct.id, join_right.id +FROM simple_struct +INNER JOIN 
join_right ON simple_struct.id = join_right.id +WHERE simple_struct.s['value'] > join_right.s['level'] +ORDER BY simple_struct.id; +---- +1 1 +2 2 +3 3 +4 4 +5 5 diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 4353f805c848b..edafcfaa543f2 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -116,11 +116,12 @@ explain select * from (select column1, unnest(column2) as o from d) where o['a'] ---- physical_plan 01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] -02)--FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 +02)--FilterExec: __datafusion_extracted_1@0 = 1, projection=[column1@1, __unnest_placeholder(d.column2,depth=1)@2] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------UnnestExec -05)--------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] +04)------ProjectionExec: expr=[get_field(__unnest_placeholder(d.column2,depth=1)@1, a) as __datafusion_extracted_1, column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as __unnest_placeholder(d.column2,depth=1)] +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table d; @@ -488,3 +489,257 @@ physical_plan statement ok drop table agg_dyn_test; + +statement ok +drop table t1; + +statement ok +drop table t2; + + + +# check LEFT/RIGHT joins with filter pushdown to both relations (when possible) + +statement ok +create table t1(k int, v int); + +statement ok +create table t2(k int, v int); + +statement ok +insert into t1 values + (1, 10), + (2, 20), + (3, 30), + (null, 40), + (50, null), + 
(null, null); + +statement ok +insert into t2 values + (1, 11), + (2, 21), + (2, 22), + (null, 41), + (51, null), + (null, null); + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.explain.logical_plan_only = true; + + +# left join + filter on join key -> pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.k > 1; +---- +2 20 2 21 +2 20 2 22 +3 30 NULL NULL +50 NULL NULL NULL + +# left join + filter on another column -> not pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.v > 1; +---- +1 10 1 11 +2 20 2 21 +2 20 2 22 +3 30 NULL NULL +NULL 40 NULL NULL + +# left join + or + filter on another column -> not pushed +query TT +explain select * from t1 left join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)Left Join: t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 left join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +3 30 NULL NULL +50 NULL NULL NULL +NULL 40 NULL NULL + + +# right join + filter on join key -> pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query IIII 
rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.k > 1; +---- +2 20 2 21 +2 20 2 22 + +# right join + filter on another column -> not pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.v > 1; +---- +1 10 1 11 +2 20 2 21 +2 20 2 22 + +# right join + or + filter on another column -> not pushed +query TT +explain select * from t1 right join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)Inner Join: t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k, v] + +query IIII rowsort +select * from t1 right join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- + + +# left anti join + filter on join key -> pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 1; +---- +logical_plan +01)LeftAnti Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 1; +---- +3 30 +50 NULL + +# left anti join + filter on another column -> not pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.v > 1; +---- +logical_plan +01)LeftAnti Join: t1.k = t2.k +02)--Filter: t1.v > Int32(1) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.v > 1; +---- +3 30 +NULL 40 + +# left anti join + or + filter on another column -> not pushed +query TT +explain select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +logical_plan +01)LeftAnti Join: 
t1.k = t2.k +02)--Filter: t1.k > Int32(3) OR t1.v > Int32(20) +03)----TableScan: t1 projection=[k, v] +04)--TableScan: t2 projection=[k] + +query II rowsort +select * from t1 left anti join t2 on t1.k = t2.k where t1.k > 3 or t1.v > 20; +---- +3 30 +50 NULL +NULL 40 + + +# right anti join + filter on join key -> pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 1; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--Filter: t1.k > Int32(1) +03)----TableScan: t1 projection=[k] +04)--Filter: t2.k > Int32(1) +05)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 1; +---- +51 NULL + +# right anti join + filter on another column -> not pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.v > 1; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--TableScan: t1 projection=[k] +03)--Filter: t2.v > Int32(1) +04)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.v > 1; +---- +NULL 41 + +# right anti join + or + filter on another column -> not pushed +query TT +explain select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 3 or t2.v > 20; +---- +logical_plan +01)RightAnti Join: t1.k = t2.k +02)--TableScan: t1 projection=[k] +03)--Filter: t2.k > Int32(3) OR t2.v > Int32(20) +04)----TableScan: t2 projection=[k, v] + +query II rowsort +select * from t1 right anti join t2 on t1.k = t2.k where t2.k > 3 or t2.v > 20; +---- +51 NULL +NULL 41 + + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt index 6f2d5a873c1b6..2b304c8de1a3c 100644 --- a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt +++ 
b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt @@ -334,5 +334,10 @@ true true false false false false +query TT +select * from regexp_test where regexp_like('f', regexp_replace((('v\r') like ('f_*sP6H1*')), '339629555', '-1459539013')); +---- + + statement ok drop table if exists dict_table; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 7be7de5a4def8..7a4a81b5faa64 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -765,11 +765,11 @@ select nanvl(null, 64); ---- NULL -# nanvl scalar nulls #1 +# nanvl scalar nulls #1 - x is not NaN, so return x even if y is NULL query R rowsort select nanvl(2, null); ---- -NULL +2 # nanvl scalar nulls #2 query R rowsort @@ -1165,7 +1165,7 @@ from small_floats; ---- 0.447 0.4 0.447 0.707 0.7 0.707 -0.837 0.8 0.837 +0.836 0.8 0.836 1 1 1 ## bitwise and @@ -1311,6 +1311,14 @@ select a << b, c << d, e << f from signed_integers; 33554432 123 10485760 NULL NULL NULL +## bitwise operations should reject non-integer types + +query error DataFusion error: Error during planning: Cannot infer common type for bitwise operation Float32 & Float32 +select arrow_cast(1, 'Float32') & arrow_cast(2, 'Float32'); + +query error DataFusion error: Error during planning: Cannot infer common type for bitwise operation Date32 & Date32 +select arrow_cast(1, 'Date32') & arrow_cast(2, 'Date32'); + statement ok drop table unsigned_integers; @@ -1993,10 +2001,10 @@ query TT EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = CAST(left(simple_string.letter2, Int64(1)) AS Utf8View) +01)Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) 02)--TableScan: simple_string projection=[letter, letter2] physical_plan -01)ProjectionExec: expr=[letter@0 as letter, letter@0 = CAST(left(letter2@1, 1) 
AS Utf8View) as simple_string.letter = left(simple_string.letter2,Int64(1))] +01)ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TB @@ -2010,8 +2018,8 @@ D false # test string_temporal_coercion query BBBBBBBBBB select - arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', - arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(s)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(s)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', @@ -2069,7 +2077,7 @@ select position('' in '') ---- 1 -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\) but received NativeType::Int64, DataType: Int64 +query error DataFusion error: Error during planning: Function 'strpos' requires TypeSignatureClass::Native\(LogicalType\(Native\(String\), String\)\), but received Int64 \(DataType: Int64\) select position(1 in 1) query I diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 490df4b72d17b..d49ccb9fe979f 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1926,3 +1926,12 @@ select "current_time" is not null from t_with_current_time; true false true + 
+# https://github.com/apache/datafusion/issues/20215 +statement count 0 +CREATE TABLE t0; + +query I +SELECT COUNT(*) FROM t0 AS tt0 WHERE (4==(3/0)); +---- +0 diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index d8c25ab25e8ea..99fc9900ef619 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -113,3 +113,21 @@ logical_plan physical_plan 01)ProjectionExec: expr=[[{x:100}] as a] 02)--PlaceholderRowExec + +# Simplify expr = L1 AND expr != L2 to expr = L1 when L1 != L2 +query TT +EXPLAIN SELECT + v = 1 AND v != 0 as opt1, + v = 2 AND v != 2 as noopt1, + v != 3 AND v = 4 as opt2, + v != 5 AND v = 5 as noopt2 +FROM (VALUES (0), (1), (2)) t(v) +---- +logical_plan +01)Projection: t.v = Int64(1) AS opt1, t.v = Int64(2) AND t.v != Int64(2) AS noopt1, t.v = Int64(4) AS opt2, t.v != Int64(5) AND t.v = Int64(5) AS noopt2 +02)--SubqueryAlias: t +03)----Projection: column1 AS v +04)------Values: (Int64(0)), (Int64(1)), (Int64(2)) +physical_plan +01)ProjectionExec: expr=[column1@0 = 1 as opt1, column1@0 = 2 AND column1@0 != 2 as noopt1, column1@0 = 4 as opt2, column1@0 != 5 AND column1@0 = 5 as noopt2] +02)--DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/sort_pushdown.slt b/datafusion/sqllogictest/test_files/sort_pushdown.slt index 58d9915a24be2..99f26b66d458b 100644 --- a/datafusion/sqllogictest/test_files/sort_pushdown.slt +++ b/datafusion/sqllogictest/test_files/sort_pushdown.slt @@ -851,7 +851,749 @@ LIMIT 3; 5 4 2 -3 +# Test 3.7: Aggregate ORDER BY expression should keep SortExec +# Source pattern declared on parquet scan: [x ASC, y ASC]. +# Requested pattern in ORDER BY: [x ASC, CAST(y AS BIGINT) % 2 ASC]. +# Example for x=1 input y order 1,2,3 gives bucket order 1,0,1, which does not +# match requested bucket ASC order. SortExec is required above AggregateExec. 
+statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE agg_expr_data(x INT, y INT, v INT) AS VALUES +(1, 1, 10), +(1, 2, 20), +(1, 3, 30), +(2, 1, 40), +(2, 2, 50), +(2, 3, 60); + +query I +COPY (SELECT * FROM agg_expr_data ORDER BY x, y) +TO 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet'; +---- +6 + +statement ok +CREATE EXTERNAL TABLE agg_expr_parquet(x INT, y INT, v INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet' +WITH ORDER (x ASC, y ASC); + +query TT +EXPLAIN SELECT + x, + CAST(y AS BIGINT) % 2, + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) % 2 +ORDER BY x, CAST(y AS BIGINT) % 2; +---- +logical_plan +01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y % Int64(2) ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64) % Int64(2)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, y, v] +physical_plan +01)SortExec: expr=[x@0 ASC NULLS LAST, agg_expr_parquet.y % Int64(2)@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) % 2 as agg_expr_parquet.y % Int64(2)], aggr=[sum(agg_expr_parquet.v)], ordering_mode=PartiallySorted([0]) +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet + +# Expected output pattern from ORDER BY [x, bucket]: +# rows grouped by x, and within each x bucket appears as 0 then 1. 
+query III +SELECT + x, + CAST(y AS BIGINT) % 2, + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) % 2 +ORDER BY x, CAST(y AS BIGINT) % 2; +---- +1 0 20 +1 1 40 +2 0 50 +2 1 100 + +# Test 3.8: Aggregate ORDER BY monotonic expression can push down (no SortExec) +query TT +EXPLAIN SELECT + x, + CAST(y AS BIGINT), + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) +ORDER BY x, CAST(y AS BIGINT); +---- +logical_plan +01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, y, v] +physical_plan +01)AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) as agg_expr_parquet.y], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet + +query III +SELECT + x, + CAST(y AS BIGINT), + SUM(v) +FROM agg_expr_parquet +GROUP BY x, CAST(y AS BIGINT) +ORDER BY x, CAST(y AS BIGINT); +---- +1 1 10 +1 2 20 +1 3 30 +2 1 40 +2 2 50 +2 3 60 + +# Test 3.9: Aggregate ORDER BY aggregate output should keep SortExec +query TT +EXPLAIN SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY SUM(v); +---- +logical_plan +01)Sort: sum(agg_expr_parquet.v) ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, v] +physical_plan +01)SortExec: expr=[sum(agg_expr_parquet.v)@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet + +query II +SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY SUM(v); +---- +1 60 +2 150 + +# Test 3.10: Aggregate with non-preserved input order should keep SortExec +# v is not part of the order by +query TT +EXPLAIN SELECT v, SUM(y) +FROM agg_expr_parquet +GROUP BY v +ORDER BY v; +---- +logical_plan +01)Sort: agg_expr_parquet.v ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.v]], aggr=[[sum(CAST(agg_expr_parquet.y AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[y, v] +physical_plan +01)SortExec: expr=[v@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[v@1 as v], aggr=[sum(agg_expr_parquet.y)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[y, v], file_type=parquet + +query II +SELECT v, SUM(y) +FROM agg_expr_parquet +GROUP BY v +ORDER BY v; +---- +10 1 +20 2 +30 3 +40 1 +50 2 +60 3 + +# Test 3.11: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec +# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1) +query TT +EXPLAIN SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY x + 1 DESC; +---- +logical_plan +01)Sort: CAST(agg_expr_parquet.x AS Int64) + Int64(1) DESC NULLS FIRST +02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, v] +physical_plan +01)SortExec: expr=[CAST(x@0 AS Int64) + 1 DESC], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet + +query II +SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY x + 1 DESC; +---- +2 150 +1 60 + +# Test 3.12: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec +# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1) +query TT +EXPLAIN SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY 2 * x ASC; +---- +logical_plan +01)Sort: Int64(2) * CAST(agg_expr_parquet.x AS Int64) ASC NULLS LAST +02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]] +03)----TableScan: agg_expr_parquet projection=[x, v] +physical_plan +01)SortExec: expr=[2 * CAST(x@0 AS Int64) ASC NULLS LAST], preserve_partitioning=[false] +02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet + +query II +SELECT x, SUM(v) +FROM agg_expr_parquet +GROUP BY x +ORDER BY 2 * x ASC; +---- +1 60 +2 150 + +# Test 4: Reversed filesystem order with inferred ordering +# Create 3 parquet files with non-overlapping id ranges, named so filesystem +# order is OPPOSITE to data order. Each file is internally sorted by id ASC. +# Force target_partitions=1 so all files end up in one file group, which is +# where the inter-file ordering bug manifests. +# Without inter-file validation, the optimizer would incorrectly trust the +# inferred ordering and remove SortExec. 
+ +# Save current target_partitions and set to 1 to force single file group +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE reversed_high(id INT, value INT) AS VALUES (7, 700), (8, 800), (9, 900); + +statement ok +CREATE TABLE reversed_mid(id INT, value INT) AS VALUES (4, 400), (5, 500), (6, 600); + +statement ok +CREATE TABLE reversed_low(id INT, value INT) AS VALUES (1, 100), (2, 200), (3, 300); + +query I +COPY (SELECT * FROM reversed_high ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/a_high.parquet'; +---- +3 + +query I +COPY (SELECT * FROM reversed_mid ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/b_mid.parquet'; +---- +3 + +query I +COPY (SELECT * FROM reversed_low ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/reversed/c_low.parquet'; +---- +3 + +# External table with NO "WITH ORDER" — relies on inferred ordering from parquet metadata +statement ok +CREATE EXTERNAL TABLE reversed_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/'; + +# Test 4.1: SortExec must be present because files are not in inter-file order +query TT +EXPLAIN SELECT * FROM reversed_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: reversed_parquet.id ASC NULLS LAST +02)--TableScan: reversed_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], file_type=parquet + +# Test 4.2: Results must be correct +query II +SELECT * FROM reversed_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + 
+# Test 5: Overlapping files with inferred ordering +# Create files with overlapping id ranges + +statement ok +CREATE TABLE overlap_x(id INT, value INT) AS VALUES (1, 100), (3, 300), (5, 500); + +statement ok +CREATE TABLE overlap_y(id INT, value INT) AS VALUES (2, 200), (4, 400), (6, 600); + +query I +COPY (SELECT * FROM overlap_x ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/overlap/file_x.parquet'; +---- +3 + +query I +COPY (SELECT * FROM overlap_y ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/overlap/file_y.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE overlap_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/overlap/'; + +# Test 5.1: SortExec must be present because files have overlapping ranges +query TT +EXPLAIN SELECT * FROM overlap_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: overlap_parquet.id ASC NULLS LAST +02)--TableScan: overlap_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/overlap/file_x.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/overlap/file_y.parquet]]}, projection=[id, value], file_type=parquet + +# Test 5.2: Results must be correct +query II +SELECT * FROM overlap_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 + +# Test 6: WITH ORDER + reversed filesystem order +# Same file setup as Test 4 but explicitly declaring ordering via WITH ORDER. +# Even with WITH ORDER, the optimizer should detect that inter-file order is wrong +# and keep SortExec. 
+ +statement ok +CREATE EXTERNAL TABLE reversed_with_order_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/' +WITH ORDER (id ASC); + +# Test 6.1: SortExec must be present despite WITH ORDER +query TT +EXPLAIN SELECT * FROM reversed_with_order_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: reversed_with_order_parquet.id ASC NULLS LAST +02)--TableScan: reversed_with_order_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], file_type=parquet + +# Test 6.2: Results must be correct +query II +SELECT * FROM reversed_with_order_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 7: Correctly ordered multi-file single group (positive case) +# Files are in CORRECT inter-file order within a single group. +# The validation should PASS and SortExec should be eliminated. 
+ +statement ok +CREATE TABLE correct_low(id INT, value INT) AS VALUES (1, 100), (2, 200), (3, 300); + +statement ok +CREATE TABLE correct_mid(id INT, value INT) AS VALUES (4, 400), (5, 500), (6, 600); + +statement ok +CREATE TABLE correct_high(id INT, value INT) AS VALUES (7, 700), (8, 800), (9, 900); + +query I +COPY (SELECT * FROM correct_low ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/a_low.parquet'; +---- +3 + +query I +COPY (SELECT * FROM correct_mid ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/b_mid.parquet'; +---- +3 + +query I +COPY (SELECT * FROM correct_high ORDER BY id ASC) +TO 'test_files/scratch/sort_pushdown/correct/c_high.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE correct_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/correct/' +WITH ORDER (id ASC); + +# Test 7.1: SortExec should be ELIMINATED — files are in correct inter-file order +query TT +EXPLAIN SELECT * FROM correct_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: correct_parquet.id ASC NULLS LAST +02)--TableScan: correct_parquet projection=[id, value] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 7.2: Results must be correct +query II +SELECT * FROM correct_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 7.3: DESC query on correctly ordered ASC files should still use SortExec +# Note: reverse_row_groups=true reverses the file list in the plan display +query TT +EXPLAIN SELECT * FROM correct_parquet ORDER BY id DESC; +---- +logical_plan +01)Sort: 
correct_parquet.id DESC NULLS FIRST +02)--TableScan: correct_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 DESC], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet]]}, projection=[id, value], file_type=parquet, reverse_row_groups=true + +query II +SELECT * FROM correct_parquet ORDER BY id DESC; +---- +9 900 +8 800 +7 700 +6 600 +5 500 +4 400 +3 300 +2 200 +1 100 + +# Test 8: DESC ordering with files in wrong inter-file DESC order +# Create files internally sorted by id DESC, but named so filesystem order +# is WRONG for DESC ordering (low values first in filesystem order). + +statement ok +CREATE TABLE desc_low(id INT, value INT) AS VALUES (3, 300), (2, 200), (1, 100); + +statement ok +CREATE TABLE desc_high(id INT, value INT) AS VALUES (9, 900), (8, 800), (7, 700); + +query I +COPY (SELECT * FROM desc_low ORDER BY id DESC) +TO 'test_files/scratch/sort_pushdown/desc_reversed/a_low.parquet'; +---- +3 + +query I +COPY (SELECT * FROM desc_high ORDER BY id DESC) +TO 'test_files/scratch/sort_pushdown/desc_reversed/b_high.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE desc_reversed_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/desc_reversed/' +WITH ORDER (id DESC); + +# Test 8.1: SortExec must be present — files are in wrong inter-file DESC order +# (a_low has 1-3, b_high has 7-9; for DESC, b_high should come first) +query TT +EXPLAIN SELECT * FROM desc_reversed_parquet ORDER BY id DESC; +---- +logical_plan +01)Sort: desc_reversed_parquet.id DESC NULLS FIRST +02)--TableScan: desc_reversed_parquet projection=[id, value] +physical_plan +01)SortExec: expr=[id@0 DESC], 
preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/desc_reversed/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/desc_reversed/b_high.parquet]]}, projection=[id, value], file_type=parquet + +# Test 8.2: Results must be correct +query II +SELECT * FROM desc_reversed_parquet ORDER BY id DESC; +---- +9 900 +8 800 +7 700 +3 300 +2 200 +1 100 + +# Test 9: Multi-column sort key validation +# Files have (category, id) ordering. Files share a boundary value on category='B' +# so column-level min/max statistics overlap on the primary key column. +# The validation conservatively rejects this because column-level stats can't +# precisely represent row-level boundaries for multi-column keys. + +statement ok +CREATE TABLE multi_col_a(category VARCHAR, id INT, value INT) AS VALUES +('A', 1, 10), ('A', 2, 20), ('B', 1, 30); + +statement ok +CREATE TABLE multi_col_b(category VARCHAR, id INT, value INT) AS VALUES +('B', 2, 40), ('C', 1, 50), ('C', 2, 60); + +query I +COPY (SELECT * FROM multi_col_a ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col/a_first.parquet'; +---- +3 + +query I +COPY (SELECT * FROM multi_col_b ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col/b_second.parquet'; +---- +3 + +statement ok +CREATE EXTERNAL TABLE multi_col_parquet(category VARCHAR, id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/multi_col/' +WITH ORDER (category ASC, id ASC); + +# Test 9.1: SortExec is present — validation conservatively rejects because +# column-level stats overlap on category='B' across both files +query TT +EXPLAIN SELECT * FROM multi_col_parquet ORDER BY category ASC, id ASC; +---- +logical_plan +01)Sort: multi_col_parquet.category ASC NULLS LAST, multi_col_parquet.id ASC NULLS LAST +02)--TableScan: multi_col_parquet projection=[category, id, value] 
+physical_plan +01)SortExec: expr=[category@0 ASC NULLS LAST, id@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col/a_first.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col/b_second.parquet]]}, projection=[category, id, value], file_type=parquet + +# Test 9.2: Results must be correct +query TII +SELECT * FROM multi_col_parquet ORDER BY category ASC, id ASC; +---- +A 1 10 +A 2 20 +B 1 30 +B 2 40 +C 1 50 +C 2 60 + +# Test 9.3: Multi-column sort with non-overlapping primary key across files +# When files don't overlap on the primary column, validation succeeds. + +statement ok +CREATE TABLE multi_col_x(category VARCHAR, id INT, value INT) AS VALUES +('A', 1, 10), ('A', 2, 20); + +statement ok +CREATE TABLE multi_col_y(category VARCHAR, id INT, value INT) AS VALUES +('B', 1, 30), ('B', 2, 40); + +query I +COPY (SELECT * FROM multi_col_x ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col_clean/x_first.parquet'; +---- +2 + +query I +COPY (SELECT * FROM multi_col_y ORDER BY category ASC, id ASC) +TO 'test_files/scratch/sort_pushdown/multi_col_clean/y_second.parquet'; +---- +2 + +statement ok +CREATE EXTERNAL TABLE multi_col_clean_parquet(category VARCHAR, id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/multi_col_clean/' +WITH ORDER (category ASC, id ASC); + +# Test 9.3a: SortExec should be eliminated — non-overlapping primary column +query TT +EXPLAIN SELECT * FROM multi_col_clean_parquet ORDER BY category ASC, id ASC; +---- +logical_plan +01)Sort: multi_col_clean_parquet.category ASC NULLS LAST, multi_col_clean_parquet.id ASC NULLS LAST +02)--TableScan: multi_col_clean_parquet projection=[category, id, value] +physical_plan DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col_clean/x_first.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/multi_col_clean/y_second.parquet]]}, projection=[category, id, value], output_ordering=[category@0 ASC NULLS LAST, id@1 ASC NULLS LAST], file_type=parquet + +# Test 9.3b: Results must be correct +query TII +SELECT * FROM multi_col_clean_parquet ORDER BY category ASC, id ASC; +---- +A 1 10 +A 2 20 +B 1 30 +B 2 40 + +# Test 10: Correctly ordered files WITH ORDER (positive counterpart to Test 6) +# Files in correct_parquet are in correct ASC order — WITH ORDER should pass validation +# and SortExec should be eliminated. + +statement ok +CREATE EXTERNAL TABLE correct_with_order_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/correct/' +WITH ORDER (id ASC); + +# Test 10.1: SortExec should be ELIMINATED — files are in correct order +query TT +EXPLAIN SELECT * FROM correct_with_order_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: correct_with_order_parquet.id ASC NULLS LAST +02)--TableScan: correct_with_order_parquet projection=[id, value] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/a_low.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/b_mid.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/correct/c_high.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 10.2: Results must be correct +query II +SELECT * FROM correct_with_order_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Test 11: Multiple file groups (target_partitions > 1) — each group has one file +# When files are spread across separate partitions (one file per group), each +# partition is trivially sorted and 
SortPreservingMergeExec handles the merge. + +# Restore higher target_partitions so files go into separate groups +statement ok +SET datafusion.execution.target_partitions = 4; + +statement ok +CREATE EXTERNAL TABLE multi_partition_parquet(id INT, value INT) +STORED AS PARQUET +LOCATION 'test_files/scratch/sort_pushdown/reversed/' +WITH ORDER (id ASC); + +# Test 11.1: With separate partitions, each file is trivially sorted. +# SortPreservingMergeExec merges, no SortExec needed per-partition. +query TT +EXPLAIN SELECT * FROM multi_partition_parquet ORDER BY id ASC; +---- +logical_plan +01)Sort: multi_partition_parquet.id ASC NULLS LAST +02)--TableScan: multi_partition_parquet projection=[id, value] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/a_high.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/b_mid.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/reversed/c_low.parquet]]}, projection=[id, value], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + +# Test 11.2: Results must be correct +query II +SELECT * FROM multi_partition_parquet ORDER BY id ASC; +---- +1 100 +2 200 +3 300 +4 400 +5 500 +6 600 +7 700 +8 800 +9 900 + +# Restore target_partitions to 1 for remaining cleanup +statement ok +SET datafusion.execution.target_partitions = 2; + # Cleanup +statement ok +DROP TABLE reversed_high; + +statement ok +DROP TABLE reversed_mid; + +statement ok +DROP TABLE reversed_low; + +statement ok +DROP TABLE reversed_parquet; + +statement ok +DROP TABLE overlap_x; + +statement ok +DROP TABLE overlap_y; + +statement ok +DROP TABLE overlap_parquet; + +statement ok +DROP TABLE reversed_with_order_parquet; + +statement ok +DROP TABLE correct_low; + +statement ok +DROP TABLE correct_mid; + +statement ok +DROP TABLE correct_high; + +statement ok 
+DROP TABLE correct_parquet; + +statement ok +DROP TABLE desc_low; + +statement ok +DROP TABLE desc_high; + +statement ok +DROP TABLE desc_reversed_parquet; + +statement ok +DROP TABLE multi_col_a; + +statement ok +DROP TABLE multi_col_b; + +statement ok +DROP TABLE multi_col_parquet; + +statement ok +DROP TABLE multi_col_x; + +statement ok +DROP TABLE multi_col_y; + +statement ok +DROP TABLE multi_col_clean_parquet; + +statement ok +DROP TABLE correct_with_order_parquet; + +statement ok +DROP TABLE multi_partition_parquet; + statement ok DROP TABLE timestamp_data; @@ -882,5 +1624,11 @@ DROP TABLE signed_data; statement ok DROP TABLE signed_parquet; +statement ok +DROP TABLE agg_expr_data; + +statement ok +DROP TABLE agg_expr_parquet; + statement ok SET datafusion.optimizer.enable_sort_pushdown = true; diff --git a/datafusion/sqllogictest/test_files/spark/README.md b/datafusion/sqllogictest/test_files/spark/README.md index cffd28009889d..e61001c6e42e5 100644 --- a/datafusion/sqllogictest/test_files/spark/README.md +++ b/datafusion/sqllogictest/test_files/spark/README.md @@ -39,6 +39,18 @@ When testing Spark functions: - Test cases should only contain `SELECT` statements with the function being tested - Add explicit casts to input values to ensure the correct data type is used (e.g., `0::INT`) - Explicit casting is necessary because DataFusion and Spark do not infer data types in the same way +- If the Spark built-in function under test behaves differently in ANSI SQL mode, please wrap your test cases like this example: + +```sql +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# Functions under test +select abs((-128)::TINYINT) + +statement ok +set datafusion.execution.enable_ansi_mode = false; +``` ### Finding Test Cases diff --git a/datafusion/sqllogictest/test_files/spark/array/slice.slt b/datafusion/sqllogictest/test_files/spark/array/slice.slt new file mode 100644 index 0000000000000..21f321033bcb2 --- /dev/null +++ 
b/datafusion/sqllogictest/test_files/spark/array/slice.slt @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query ? +SELECT slice([], 2, 2); +---- +[] + +query ? +SELECT slice([1, 2, 3, 4], 2, 2); +---- +[2, 3] + +query ? +SELECT slice([1, 2, 3, 4], 1, 100); +---- +[1, 2, 3, 4] + +query ? +SELECT slice([1, 2, 3, 4], -2, 2); +---- +[3, 4] + +query ? +SELECT slice([1, 2, 3, 4], 100, 2); +---- +[] + +query ? +SELECT slice([1, 2, 3, 4], -200, 2); +---- +[] + +query error DataFusion error: Execution error: Length must be non-negative, but got -2 +SELECT slice([1, 2, 3, 4], 2, -2); + +query error DataFusion error: Execution error: Length must be non-negative, but got -2 +SELECT slice([1, 2, 3, 4], -2, -2); + +query error DataFusion error: Execution error: Start index must not be zero +SELECT slice([1, 2, 3, 4], 0, -2); + +query ? +SELECT slice([NULL, NULL, NULL, NULL, NULL], 2, 2); +---- +[NULL, NULL] + +query ? +SELECT slice(arrow_cast(NULL, 'FixedSizeList(1, Int64)'), 2, 2); +---- +NULL + +query ? +SELECT slice([1, 2, 3, 4], NULL, 2); +---- +NULL + +query ? +SELECT slice([1, 2, 3, 4], 2, NULL); +---- +NULL + + +query ? 
+SELECT slice(column1, column2, column3) +FROM VALUES +([1, 2, 3, 4], 2, 2), +([1, 2, 3, 4], 1, 100), +([1, 2, 3, 4], -2, 2), +([], 2, 2), +([1, 2, 3, 4], 100, 2), +([1, 2, 3, 4], -200, 2), +([NULL, NULL, NULL, NULL, NULL], 2, 2), +(arrow_cast(NULL, 'FixedSizeList(1, Int64)'), 2, 2), +([1, 2, 3, 4], NULL, 2), +([1, 2, 3, 4], 2, NULL); +---- +[2, 3] +[1, 2, 3, 4] +[3, 4] +[] +[] +[] +[NULL, NULL] +NULL +NULL +NULL diff --git a/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt new file mode 100644 index 0000000000000..4af3193a5db31 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_bit_position.slt @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +query I +SELECT bitmap_bit_position(arrow_cast(1, 'Int8')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(3, 'Int8')); +---- +2 + +query I +SELECT bitmap_bit_position(arrow_cast(7, 'Int8')); +---- +6 + +query I +SELECT bitmap_bit_position(arrow_cast(15, 'Int8')); +---- +14 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int8')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(256, 'Int16')); +---- +255 + +query I +SELECT bitmap_bit_position(arrow_cast(1024, 'Int16')); +---- +1023 + +query I +SELECT bitmap_bit_position(arrow_cast(-32768, 'Int16')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(16384, 'Int16')); +---- +16383 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int16')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(65536, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(1048576, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-2147483648, 'Int32')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(1073741824, 'Int32')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int32')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(4294967296, 'Int64')); +---- +32767 + +query I +SELECT bitmap_bit_position(arrow_cast(-1, 'Int64')); +---- +1 + +query I +SELECT bitmap_bit_position(arrow_cast(-9223372036854775808, 'Int64')); +---- +0 + +query I +SELECT bitmap_bit_position(arrow_cast(9223372036854775807, 'Int64')); +---- +32766 diff --git a/datafusion/sqllogictest/test_files/spark/collection/size.slt b/datafusion/sqllogictest/test_files/spark/collection/size.slt new file mode 100644 index 0000000000000..106760eebfe42 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/size.slt @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT size(array(1, 2, 3)); +## PySpark 3.5.5 Result: {'size(array(1, 2, 3))': 3} + +# Basic array +query I +SELECT size(make_array(1, 2, 3)); +---- +3 + +# Nested array +query I +SELECT size(make_array(make_array(1, 2), make_array(3, 4, 5))); +---- +2 + +# LargeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'LargeList(Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')); +---- +5 + +# FixedSizeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int32)')); +---- +4 + +# Map size tests +query I +SELECT size(map(make_array('a', 'b', 'c'), make_array(1, 2, 3))); +---- +3 + +query I +SELECT size(map(make_array('a'), make_array(1))); +---- +1 + +# Empty array +query I +SELECT size(arrow_cast(make_array(), 
'List(Int32)')); +---- +0 + + +# Array with NULL elements (size counts elements including NULLs) +query I +SELECT size(make_array(1, NULL, 3)); +---- +3 + +# NULL array returns -1 (Spark behavior) +query I +SELECT size(NULL::int[]); +---- +-1 + + +# Empty map +query I +SELECT size(map(arrow_cast(make_array(), 'List(Utf8)'), arrow_cast(make_array(), 'List(Int32)'))); +---- +0 + +# String array +query I +SELECT size(make_array('hello', 'world')); +---- +2 + +# Boolean array +query I +SELECT size(make_array(true, false, true)); +---- +3 + +# Float array +query I +SELECT size(make_array(1.5, 2.5, 3.5, 4.5)); +---- +4 + +# Array column tests (with NULL values) +query I +SELECT size(column1) FROM VALUES ([1]), ([1,2]), ([]), (NULL); +---- +1 +2 +0 +-1 + +# Map column tests (with NULL values) +query I +SELECT size(column1) FROM VALUES (map(['a'], [1])), (map(['a','b'], [1,2])), (NULL); +---- +1 +2 +-1 diff --git a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt index cae9b21dd4766..55a493ffefe26 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt @@ -15,13 +15,45 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT add_months('2016-08-31', 1); -## PySpark 3.5.5 Result: {'add_months(2016-08-31, 1)': datetime.date(2016, 9, 30), 'typeof(add_months(2016-08-31, 1))': 'date', 'typeof(2016-08-31)': 'string', 'typeof(1)': 'int'} -#query -#SELECT add_months('2016-08-31'::string, 1::int); +query D +SELECT add_months('2016-07-30'::date, 1::int); +---- +2016-08-30 + +query D +SELECT add_months('2016-07-30'::date, 0::int); +---- +2016-07-30 + +query D +SELECT add_months('2016-07-30'::date, 10000::int); +---- +2849-11-30 + +# Test integer overflow +# TODO: Enable with next arrow upgrade (>=58.0.0) +# query D +# SELECT add_months('2016-07-30'::date, 2147483647::int); +# ---- +# NULL + +query D +SELECT add_months('2016-07-30'::date, -5::int); +---- +2016-02-29 + +# Test with NULL values +query D +SELECT add_months(NULL::date, 1::int); +---- +NULL + +query D +SELECT add_months('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT add_months(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt index a2ac7cf2edb11..cb407a6453696 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt @@ -41,7 +41,7 @@ SELECT date_add('2016-07-30'::date, arrow_cast(1, 'Int8')); 2016-07-31 query D -SELECT date_sub('2016-07-30'::date, 0::int); +SELECT date_add('2016-07-30'::date, 0::int); ---- 2016-07-30 @@ -51,20 +51,15 @@ SELECT date_add('2016-07-30'::date, 2147483647::int)::int; -2147466637 query I -SELECT date_sub('1969-01-01'::date, 2147483647::int)::int; +SELECT date_add('1969-01-01'::date, 2147483647::int)::int; ---- -2147483284 +2147483282 query D SELECT date_add('2016-07-30'::date, 100000::int); ---- 2290-05-15 -query D -SELECT date_sub('2016-07-30'::date, 100000::int); 
----- -1742-10-15 - # Test with negative day values (should subtract days) query D SELECT date_add('2016-07-30'::date, -5::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt index c5871ab41e183..b0952d6a43510 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt @@ -15,18 +15,138 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_diff('2009-07-30', '2009-07-31'); -## PySpark 3.5.5 Result: {'date_diff(2009-07-30, 2009-07-31)': -1, 'typeof(date_diff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} -#query -#SELECT date_diff('2009-07-30'::string, '2009-07-31'::string); - -## Original Query: SELECT date_diff('2009-07-31', '2009-07-30'); -## PySpark 3.5.5 Result: {'date_diff(2009-07-31, 2009-07-30)': 1, 'typeof(date_diff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} -#query -#SELECT date_diff('2009-07-31'::string, '2009-07-30'::string); +# date input +query I +SELECT date_diff('2009-07-30'::date, '2009-07-31'::date); +---- +-1 + +query I +SELECT date_diff('2009-07-31'::date, '2009-07-30'::date); +---- +1 + +query I +SELECT date_diff('2009-07-31'::string, '2009-07-30'::date); +---- +1 + +query I +SELECT date_diff('2009-07-31'::timestamp, '2009-07-30'::date); +---- +1 + +# Date64 input +query I +SELECT date_diff(arrow_cast('2009-07-31', 'Date64'), 
arrow_cast('2009-07-30', 'Date64')); +---- +1 + +query I +SELECT date_diff(arrow_cast('2009-07-30', 'Date64'), arrow_cast('2009-07-31', 'Date64')); +---- +-1 + +# Mixed Date32 and Date64 input +query I +SELECT date_diff('2009-07-31'::date, arrow_cast('2009-07-30', 'Date64')); +---- +1 + +query I +SELECT date_diff(arrow_cast('2009-07-31', 'Date64'), '2009-07-30'::date); +---- +1 + + +# Same date returns 0 +query I +SELECT date_diff('2009-07-30'::date, '2009-07-30'::date); +---- +0 + +# Large difference +query I +SELECT date_diff('2020-01-01'::date, '1970-01-01'::date); +---- +18262 + +# timestamp input +query I +SELECT date_diff('2009-07-30 12:34:56'::timestamp, '2009-07-31 23:45:01'::timestamp); +---- +-1 + +query I +SELECT date_diff('2009-07-31 23:45:01'::timestamp, '2009-07-30 12:34:56'::timestamp); +---- +1 + +query I +SELECT date_diff('2009-07-31 23:45:01'::string, '2009-07-30 12:34:56'::timestamp); +---- +1 + +# string input +query I +SELECT date_diff('2009-07-30', '2009-07-31'); +---- +-1 + +query I +SELECT date_diff('2009-07-31', '2009-07-30'); +---- +1 + +# NULL handling +query I +SELECT date_diff(NULL::date, '2009-07-30'::date); +---- +NULL + +query I +SELECT date_diff('2009-07-31'::date, NULL::date); +---- +NULL + +query I +SELECT date_diff(NULL::date, NULL::date); +---- +NULL + +query I +SELECT date_diff(column1, column2) +FROM VALUES +('2009-07-30'::date, '2009-07-31'::date), +('2009-07-31'::date, '2009-07-30'::date), +(NULL::date, '2009-07-30'::date), +('2009-07-31'::date, NULL::date), +(NULL::date, NULL::date); +---- +-1 +1 +NULL +NULL +NULL + + +# Alias datediff +query I +SELECT datediff('2009-07-30'::date, '2009-07-31'::date); +---- +-1 + +query I +SELECT datediff(column1, column2) +FROM VALUES +('2009-07-30'::date, '2009-07-31'::date), +('2009-07-31'::date, '2009-07-30'::date), +(NULL::date, '2009-07-30'::date), +('2009-07-31'::date, NULL::date), +(NULL::date, NULL::date); +---- +-1 +1 +NULL +NULL +NULL diff --git 
a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt index cd3271cdc7df8..48216bd551692 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt @@ -15,48 +15,262 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_part('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); -## PySpark 3.5.5 Result: {"date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} -#query -#SELECT date_part('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); - -## Original Query: SELECT date_part('MONTH', INTERVAL '2021-11' YEAR TO MONTH); -## PySpark 3.5.5 Result: {"date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} -#query -#SELECT date_part('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); - -## Original Query: SELECT date_part('SECONDS', timestamp'2019-10-01 00:00:01.000001'); -## PySpark 3.5.5 Result: {"date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001'))": 
'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} -#query -#SELECT date_part('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); - -## Original Query: SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT date_part('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); - -## Original Query: SELECT date_part('days', interval 5 days 3 hours 7 minutes); -## PySpark 3.5.5 Result: {"date_part(days, INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(date_part(days, INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} -#query -#SELECT date_part('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); - -## Original Query: SELECT date_part('doy', DATE'2019-08-12'); -## PySpark 3.5.5 Result: {"date_part(doy, DATE '2019-08-12')": 224, "typeof(date_part(doy, DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} -#query -#SELECT date_part('doy'::string, DATE '2019-08-12'::date); - -## Original Query: SELECT date_part('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); -## PySpark 3.5.5 Result: {"date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} -#query -#SELECT date_part('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); - -## Original Query: SELECT date_part('week', timestamp'2019-08-12 
01:00:00.123456'); -## PySpark 3.5.5 Result: {"date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT date_part('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); +# YEAR +query I +SELECT date_part('YEAR'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('YEARS'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('Y'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('YR'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT date_part('YRS'::string, '2000-01-01'::date); +---- +2000 + +# YEAROFWEEK +query I +SELECT date_part('YEAROFWEEK'::string, '2000-01-01'::date); +---- +1999 + +# QUARTER +query I +SELECT date_part('QUARTER'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('QTR'::string, '2000-01-01'::date); +---- +1 + +# MONTH +query I +SELECT date_part('MONTH'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MON'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MONS'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('MONTHS'::string, '2000-01-01'::date); +---- +1 + +# WEEK +query I +SELECT date_part('WEEK'::string, '2000-01-01'::date); +---- +52 + +query I +SELECT date_part('WEEKS'::string, '2000-01-01'::date); +---- +52 + +query I +SELECT date_part('W'::string, '2000-01-01'::date); +---- +52 + +# DAYS +query I +SELECT date_part('DAY'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('D'::string, '2000-01-01'::date); +---- +1 + +query I +SELECT date_part('DAYS'::string, '2000-01-01'::date); +---- +1 + +# DAYOFWEEK +query I +SELECT date_part('DAYOFWEEK'::string, '2000-01-01'::date); +---- +7 + +query I +SELECT date_part('DOW'::string, '2000-01-01'::date); +---- +7 + +# DAYOFWEEK_ISO +query I +SELECT 
date_part('DAYOFWEEK_ISO'::string, '2000-01-01'::date); +---- +6 + +query I +SELECT date_part('DOW_ISO'::string, '2000-01-01'::date); +---- +6 + +# DOY +query I +SELECT date_part('DOY'::string, '2000-01-01'::date); +---- +1 + +# HOUR +query I +SELECT date_part('HOUR'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('H'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HOURS'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HR'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +query I +SELECT date_part('HRS'::string, '2000-01-01 12:30:45'::timestamp); +---- +12 + +# MINUTE +query I +SELECT date_part('MINUTE'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('M'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MIN'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MINS'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +query I +SELECT date_part('MINUTES'::string, '2000-01-01 12:30:45'::timestamp); +---- +30 + +# SECOND +query I +SELECT date_part('SECOND'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('S'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SEC'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SECONDS'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +query I +SELECT date_part('SECS'::string, '2000-01-01 12:30:45'::timestamp); +---- +45 + +# NULL input +query I +SELECT date_part('year'::string, NULL::timestamp); +---- +NULL + +query error Internal error: First argument of `DATE_PART` must be non-null scalar Utf8 +SELECT date_part(NULL::string, '2000-01-01'::date); + +# Invalid part +query error DataFusion error: Execution error: Date part 'test' not supported +SELECT date_part('test'::string, '2000-01-01'::date); + +query I 
+SELECT date_part('year', column1) +FROM VALUES +('2022-03-15'::date), +('1999-12-31'::date), +('2000-01-01'::date), +(NULL::date); +---- +2022 +1999 +2000 +NULL + +query I +SELECT date_part('minutes', column1) +FROM VALUES +('2022-03-15 12:30:45'::timestamp), +('1999-12-31 12:32:45'::timestamp), +('2000-01-01 12:00:45'::timestamp), +(NULL::timestamp); +---- +30 +32 +0 +NULL + +# alias datepart +query I +SELECT datepart('YEAR'::string, '2000-01-01'::date); +---- +2000 + +query I +SELECT datepart('year', column1) +FROM VALUES +('2022-03-15'::date), +('1999-12-31'::date), +('2000-01-01'::date), +(NULL::date); +---- +2022 +1999 +2000 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt index cb5e77c3b4f1e..bf36ebd867d19 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt @@ -45,6 +45,16 @@ SELECT date_sub('2016-07-30'::date, 0::int); ---- 2016-07-30 +query I +SELECT date_sub('1969-01-01'::date, 2147483647::int)::int; +---- +2147483284 + +query D +SELECT date_sub('2016-07-30'::date, 100000::int); +---- +1742-10-15 + # Test with negative day values (should add days) query D SELECT date_sub('2016-07-30'::date, -1::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt index 8a15254e6795e..7fc1583bb9310 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt @@ -15,33 +15,150 @@ # specific language governing permissions and limitations # under the License. 
-# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT date_trunc('DD', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(DD, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 0, 0), 'typeof(date_trunc(DD, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(DD)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('DD'::string, '2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('HOUR', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(HOUR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 9, 0), 'typeof(date_trunc(HOUR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(HOUR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('HOUR'::string, '2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('MILLISECOND', '2015-03-05T09:32:05.123456'); -## PySpark 3.5.5 Result: {'date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456)': datetime.datetime(2015, 3, 5, 9, 32, 5, 123000), 'typeof(date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456))': 'timestamp', 'typeof(MILLISECOND)': 'string', 'typeof(2015-03-05T09:32:05.123456)': 'string'} -#query -#SELECT date_trunc('MILLISECOND'::string, '2015-03-05T09:32:05.123456'::string); - -## Original Query: SELECT date_trunc('MM', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(MM, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 1, 0, 0), 'typeof(date_trunc(MM, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(MM)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('MM'::string, 
'2015-03-05T09:32:05.359'::string); - -## Original Query: SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359'); -## PySpark 3.5.5 Result: {'date_trunc(YEAR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 1, 1, 0, 0), 'typeof(date_trunc(YEAR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(YEAR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} -#query -#SELECT date_trunc('YEAR'::string, '2015-03-05T09:32:05.359'::string); +# YEAR - truncate to first date of year, time zeroed +query P +SELECT date_trunc('YEAR', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +query P +SELECT date_trunc('YYYY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +query P +SELECT date_trunc('YY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-01-01T00:00:00 + +# QUARTER - truncate to first date of quarter, time zeroed +query P +SELECT date_trunc('QUARTER', '2015-05-05T09:32:05.123456'::timestamp); +---- +2015-04-01T00:00:00 + +# MONTH - truncate to first date of month, time zeroed +query P +SELECT date_trunc('MONTH', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +query P +SELECT date_trunc('MM', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +query P +SELECT date_trunc('MON', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-01T00:00:00 + +# WEEK - truncate to Monday of the week, time zeroed +query P +SELECT date_trunc('WEEK', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-02T00:00:00 + +# DAY - zero out time part +query P +SELECT date_trunc('DAY', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T00:00:00 + +query P +SELECT date_trunc('DD', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T00:00:00 + +# HOUR - zero out minute and second with fraction +query P +SELECT date_trunc('HOUR', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:00:00 + +# MINUTE - zero out second with fraction +query P +SELECT date_trunc('MINUTE', 
'2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:00 + +# SECOND - zero out fraction +query P +SELECT date_trunc('SECOND', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05 + +# MILLISECOND - zero out microseconds +query P +SELECT date_trunc('MILLISECOND', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05.123 + +# MICROSECOND - everything remains +query P +SELECT date_trunc('MICROSECOND', '2015-03-05T09:32:05.123456'::timestamp); +---- +2015-03-05T09:32:05.123456 + +query P +SELECT date_trunc('YEAR', column1) +FROM VALUES +('2015-03-05T09:32:05.123456'::timestamp), +('2020-11-15T22:45:30.654321'::timestamp), +('1999-07-20T14:20:10.000001'::timestamp), +(NULL::timestamp); +---- +2015-01-01T00:00:00 +2020-01-01T00:00:00 +1999-01-01T00:00:00 +NULL + +# String input +query P +SELECT date_trunc('YEAR', '2015-03-05T09:32:05.123456'); +---- +2015-01-01T00:00:00 + +# Null handling +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: First argument of `DATE_TRUNC` must be non-null scalar Utf8 +SELECT date_trunc(NULL, '2015-03-05T09:32:05.123456'); + +query P +SELECT date_trunc('YEAR', NULL::timestamp); +---- +NULL + +# incorrect format +query error DataFusion error: Execution error: Unsupported date_trunc granularity: 'test'. 
Supported values are: microsecond, millisecond, second, minute, hour, day, week, month, quarter, year +SELECT date_trunc('test', '2015-03-05T09:32:05.123456'); + +# Timezone handling - Spark-compatible behavior +# Spark converts timestamps to session timezone before truncating for coarse granularities + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, Some("UTC"))')); +---- +2024-07-15T00:00:00Z + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, None)')); +---- +2024-07-15T00:00:00 + +statement ok +SET datafusion.execution.time_zone = 'America/New_York'; + +# This timestamp is 03:30 UTC = 23:30 EDT (previous day) on July 14 +# With session timezone, truncation happens in America/New_York timezone +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, Some("UTC"))')); +---- +2024-07-14T00:00:00Z + +query P +SELECT date_trunc('DAY', arrow_cast(timestamp '2024-07-15T03:30:00', 'Timestamp(Microsecond, None)')); +---- +2024-07-15T00:00:00 + +statement ok +RESET datafusion.execution.time_zone; diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt deleted file mode 100644 index 223e2c313ae86..0000000000000 --- a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT datediff('2009-07-30', '2009-07-31'); -## PySpark 3.5.5 Result: {'datediff(2009-07-30, 2009-07-31)': -1, 'typeof(datediff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} -#query -#SELECT datediff('2009-07-30'::string, '2009-07-31'::string); - -## Original Query: SELECT datediff('2009-07-31', '2009-07-30'); -## PySpark 3.5.5 Result: {'datediff(2009-07-31, 2009-07-30)': 1, 'typeof(datediff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} -#query -#SELECT datediff('2009-07-31'::string, '2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt deleted file mode 100644 index b2dd0089c2823..0000000000000 --- a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT datepart('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); -## PySpark 3.5.5 Result: {"datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} -#query -#SELECT datepart('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); - -## Original Query: SELECT datepart('MONTH', INTERVAL '2021-11' YEAR TO MONTH); -## PySpark 3.5.5 Result: {"datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} -#query -#SELECT datepart('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); - -## Original Query: SELECT datepart('SECONDS', timestamp'2019-10-01 00:00:01.000001'); -## 
PySpark 3.5.5 Result: {"datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} -#query -#SELECT datepart('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); - -## Original Query: SELECT datepart('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT datepart('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); - -## Original Query: SELECT datepart('days', interval 5 days 3 hours 7 minutes); -## PySpark 3.5.5 Result: {"datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} -#query -#SELECT datepart('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); - -## Original Query: SELECT datepart('doy', DATE'2019-08-12'); -## PySpark 3.5.5 Result: {"datepart(doy FROM DATE '2019-08-12')": 224, "typeof(datepart(doy FROM DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} -#query -#SELECT datepart('doy'::string, DATE '2019-08-12'::date); - -## Original Query: SELECT datepart('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); -## PySpark 3.5.5 Result: {"datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour 
to second'} -#query -#SELECT datepart('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); - -## Original Query: SELECT datepart('week', timestamp'2019-08-12 01:00:00.123456'); -## PySpark 3.5.5 Result: {"datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} -#query -#SELECT datepart('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt new file mode 100644 index 0000000000000..5a39bda0a651b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/from_utc_timestamp.slt @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# String inputs +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'UTC'::string); +---- +2016-08-31T00:00:00 + +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); +---- +2016-08-31T09:00:00 + +query P +SELECT from_utc_timestamp('2016-08-31'::string, 'America/New_York'::string); +---- +2016-08-30T20:00:00 + +# String inputs with offsets +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string); +---- +2018-03-13T13:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string); +---- +2018-03-13T00:18:23 + +# Timestamp inputs +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string); +---- +2018-03-13T13:18:23 + +query P +SELECT from_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string); +---- +2018-03-13T00:18:23 + +# Null inputs +query P +SELECT from_utc_timestamp(NULL::string, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT from_utc_timestamp(NULL::timestamp, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT from_utc_timestamp('2016-08-31'::string, NULL::string); +---- +NULL + +query P +SELECT from_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::string, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string), +('2016-08-31'::string, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::string, 'UTC'::string), +('2016-08-31'::string, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string), +(NULL::string, 'Asia/Seoul'::string), +('2016-08-31'::string, NULL::string); +---- +2016-08-31T09:00:00 +2018-03-13T13:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 
+2016-08-30T20:00:00 +2018-03-13T00:18:23 +NULL +NULL + +query P +SELECT from_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-31T09:00:00 +2018-03-13T13:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-30T20:00:00 +2018-03-13T00:18:23 +NULL +NULL + +query P +SELECT from_utc_timestamp(arrow_cast(column1, 'Timestamp(Microsecond, Some("Asia/Seoul"))'), column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-31T09:00:00+09:00 +2018-03-13T13:18:23+09:00 +2016-08-31T00:00:00+09:00 +2018-03-13T04:18:23+09:00 +2016-08-30T20:00:00+09:00 +2018-03-13T00:18:23+09:00 +NULL +NULL + + +# DST edge cases +query P +SELECT from_utc_timestamp('2020-03-31T13:40:00'::timestamp, 'America/New_York'::string); +---- +2020-03-31T09:40:00 + + +query P +SELECT from_utc_timestamp('2020-11-04T14:06:40'::timestamp, 'America/New_York'::string); +---- +2020-11-04T09:06:40 diff --git a/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt new file mode 100644 index 0000000000000..35ffa483bb068 --- /dev/null +++ 
b/datafusion/sqllogictest/test_files/spark/datetime/time_trunc.slt @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# HOUR - zero out minute and second with fraction +query D +SELECT time_trunc('HOUR', '09:32:05.123456'::time); +---- +09:00:00 + +# MINUTE - zero out second with fraction +query D +SELECT time_trunc('MINUTE', '09:32:05.123456'::time); +---- +09:32:00 + +# SECOND - zero out fraction +query D +SELECT time_trunc('SECOND', '09:32:05.123456'::time); +---- +09:32:05 + +# MILLISECOND - zero out microseconds +query D +SELECT time_trunc('MILLISECOND', '09:32:05.123456'::time); +---- +09:32:05.123 + +# MICROSECOND - everything remains +query D +SELECT time_trunc('MICROSECOND', '09:32:05.123456'::time); +---- +09:32:05.123456 + +query D +SELECT time_trunc('HOUR', column1) +FROM VALUES +('09:32:05.123456'::time), +('22:45:30.654321'::time), +('14:20:10.000001'::time), +(NULL::time); +---- +09:00:00 +22:00:00 +14:00:00 +NULL + + +# Null handling +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: First argument of `TIME_TRUNC` must be non-null scalar Utf8 +SELECT time_trunc(NULL, '09:32:05.123456'::time); + +query D +SELECT time_trunc('HOUR', NULL::time); 
+---- +NULL + +# incorrect format +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond +SELECT time_trunc('test', '09:32:05.123456'::time); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt index 24693016be1a7..086716e5bcd0e 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt @@ -15,13 +15,143 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul'); -## PySpark 3.5.5 Result: {'to_utc_timestamp(2016-08-31, Asia/Seoul)': datetime.datetime(2016, 8, 30, 15, 0), 'typeof(to_utc_timestamp(2016-08-31, Asia/Seoul))': 'timestamp', 'typeof(2016-08-31)': 'string', 'typeof(Asia/Seoul)': 'string'} -#query -#SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); + +# String inputs +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'UTC'::string); +---- +2016-08-31T00:00:00 + +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); +---- +2016-08-30T15:00:00 + +query P +SELECT to_utc_timestamp('2016-08-31'::string, 'America/New_York'::string); +---- +2016-08-31T04:00:00 + +# String inputs with offsets +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string); +---- +2018-03-12T19:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string); +---- +2018-03-13T08:18:23 + +# Timestamp inputs +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string); +---- +2018-03-13T04:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string); +---- +2018-03-12T19:18:23 + +query P +SELECT to_utc_timestamp('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string); +---- +2018-03-13T08:18:23 + +# Null inputs +query P +SELECT to_utc_timestamp(NULL::string, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT to_utc_timestamp(NULL::timestamp, 'Asia/Seoul'::string); +---- +NULL + +query P +SELECT to_utc_timestamp('2016-08-31'::string, NULL::string); +---- +NULL + +query P +SELECT to_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::string, 
'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::string, 'Asia/Seoul'::string), +('2016-08-31'::string, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::string, 'UTC'::string), +('2016-08-31'::string, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::string, 'America/New_York'::string), +(NULL::string, 'Asia/Seoul'::string), +('2016-08-31'::string, NULL::string); +---- +2016-08-30T15:00:00 +2018-03-12T19:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-31T04:00:00 +2018-03-13T08:18:23 +NULL +NULL + +query P +SELECT to_utc_timestamp(column1, column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-30T15:00:00 +2018-03-12T19:18:23 +2016-08-31T00:00:00 +2018-03-13T04:18:23 +2016-08-31T04:00:00 +2018-03-13T08:18:23 +NULL +NULL + +query P +SELECT to_utc_timestamp(arrow_cast(column1, 'Timestamp(Microsecond, Some("Asia/Seoul"))'), column2) +FROM VALUES +('2016-08-31'::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'Asia/Seoul'::string), +('2016-08-31'::timestamp, 'UTC'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'UTC'::string), +('2016-08-31'::timestamp, 'America/New_York'::string), +('2018-03-13T06:18:23+02:00'::timestamp, 'America/New_York'::string), +(NULL::timestamp, 'Asia/Seoul'::string), +('2018-03-13T06:18:23+00:00'::timestamp, NULL::string); +---- +2016-08-30T15:00:00+09:00 +2018-03-12T19:18:23+09:00 +2016-08-31T00:00:00+09:00 +2018-03-13T04:18:23+09:00 +2016-08-31T04:00:00+09:00 +2018-03-13T08:18:23+09:00 +NULL +NULL + + +# DST edge cases +query P +SELECT 
to_utc_timestamp('2020-03-31T13:40:00'::timestamp, 'America/New_York'::string); +---- +2020-03-31T17:40:00 + + +query P +SELECT to_utc_timestamp('2020-11-04T14:06:40'::timestamp, 'America/New_York'::string); +---- +2020-11-04T19:06:40 diff --git a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt index a502e2f7f7b00..aa26d7bd0ef06 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt @@ -15,28 +15,78 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT trunc('2009-02-12', 'MM'); -## PySpark 3.5.5 Result: {'trunc(2009-02-12, MM)': datetime.date(2009, 2, 1), 'typeof(trunc(2009-02-12, MM))': 'date', 'typeof(2009-02-12)': 'string', 'typeof(MM)': 'string'} -#query -#SELECT trunc('2009-02-12'::string, 'MM'::string); - -## Original Query: SELECT trunc('2015-10-27', 'YEAR'); -## PySpark 3.5.5 Result: {'trunc(2015-10-27, YEAR)': datetime.date(2015, 1, 1), 'typeof(trunc(2015-10-27, YEAR))': 'date', 'typeof(2015-10-27)': 'string', 'typeof(YEAR)': 'string'} -#query -#SELECT trunc('2015-10-27'::string, 'YEAR'::string); - -## Original Query: SELECT trunc('2019-08-04', 'quarter'); -## PySpark 3.5.5 Result: {'trunc(2019-08-04, quarter)': datetime.date(2019, 7, 1), 'typeof(trunc(2019-08-04, quarter))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(quarter)': 'string'} -#query -#SELECT trunc('2019-08-04'::string, 'quarter'::string); - -## Original Query: SELECT trunc('2019-08-04', 'week'); -## PySpark 3.5.5 
Result: {'trunc(2019-08-04, week)': datetime.date(2019, 7, 29), 'typeof(trunc(2019-08-04, week))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(week)': 'string'} -#query -#SELECT trunc('2019-08-04'::string, 'week'::string); +# YEAR - truncate to first date of year +query D +SELECT trunc('2009-02-12'::date, 'YEAR'::string); +---- +2009-01-01 + +query D +SELECT trunc('2009-02-12'::date, 'YYYY'::string); +---- +2009-01-01 + +query D +SELECT trunc('2009-02-12'::date, 'YY'::string); +---- +2009-01-01 + +# QUARTER - truncate to first date of quarter +query D +SELECT trunc('2009-02-12'::date, 'QUARTER'::string); +---- +2009-01-01 + +# MONTH - truncate to first date of month +query D +SELECT trunc('2009-02-12'::date, 'MONTH'::string); +---- +2009-02-01 + +query D +SELECT trunc('2009-02-12'::date, 'MM'::string); +---- +2009-02-01 + +query D +SELECT trunc('2009-02-12'::date, 'MON'::string); +---- +2009-02-01 + +# WEEK - truncate to Monday of the week +query D +SELECT trunc('2009-02-12'::date, 'WEEK'::string); +---- +2009-02-09 + +# string input +query D +SELECT trunc('2009-02-12'::string, 'YEAR'::string); +---- +2009-01-01 + +query D +SELECT trunc(column1, 'YEAR'::string) +FROM VALUES +('2009-02-12'::date), +('2000-02-12'::date), +('2042-02-12'::date), +(NULL::date); +---- +2009-01-01 +2000-01-01 +2042-01-01 +NULL + +# Null handling +query D +SELECT trunc(NULL::date, 'YEAR'::string); +---- +NULL + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Second argument of `TRUNC` must be non-null scalar Utf8 +SELECT trunc('2009-02-12'::date, NULL::string); + +# incorrect format +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter. 
+SELECT trunc('2009-02-12'::date, 'test'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/unix.slt b/datafusion/sqllogictest/test_files/spark/datetime/unix.slt new file mode 100644 index 0000000000000..d7441f487d037 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/unix.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Unix Date tests + +query I +SELECT unix_date('1970-01-02'::date); +---- +1 + +query I +SELECT unix_date('1900-01-02'::date); +---- +-25566 + + +query I +SELECT unix_date(arrow_cast('1970-01-02', 'Date64')); +---- +1 + +query I +SELECT unix_date(NULL::date); +---- +NULL + +query error Function 'unix_date' requires TypeSignatureClass::Native\(LogicalType\(Native\(Date\), Date\)\), but received String \(DataType: Utf8View\) +SELECT unix_date('1970-01-02'::string); + +# Unix Micro Tests + +query I +SELECT unix_micros('1970-01-01 00:00:01Z'::timestamp); +---- +1000000 + +query I +SELECT unix_micros('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799000000 + +query I +SELECT unix_micros(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199000000 + +query I +SELECT unix_micros(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1000000 + +query I +SELECT unix_micros(NULL::timestamp); +---- +NULL + +query error Function 'unix_micros' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_micros('1970-01-01 00:00:01Z'::string); + + +# Unix Millis Tests + +query I +SELECT unix_millis('1970-01-01 00:00:01Z'::timestamp); +---- +1000 + +query I +SELECT unix_millis('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799000 + +query I +SELECT unix_millis(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199000 + +query I +SELECT unix_millis(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1000 + +query I +SELECT unix_millis(NULL::timestamp); +---- +NULL + +query error Function 'unix_millis' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_millis('1970-01-01 00:00:01Z'::string); + + +# Unix Seconds Tests + +query I +SELECT unix_seconds('1970-01-01 00:00:01Z'::timestamp); +---- +1 + +query I +SELECT unix_seconds('1900-01-01 00:00:01Z'::timestamp); +---- +-2208988799 + +query I 
+SELECT unix_seconds(arrow_cast('1970-01-01 00:00:01+02:00', 'Timestamp(Microsecond, None)')); +---- +-7199 + +query I +SELECT unix_seconds(arrow_cast('1970-01-01 00:00:01Z', 'Timestamp(Second, None)')); +---- +1 + +query I +SELECT unix_seconds(NULL::timestamp); +---- +NULL + +query error Function 'unix_seconds' requires TypeSignatureClass::Timestamp, but received String \(DataType: Utf8View\) +SELECT unix_seconds('1970-01-01 00:00:01Z'::string); diff --git a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt index 6fbeb11fb9a36..df5588c75837d 100644 --- a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt +++ b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt @@ -81,7 +81,7 @@ SELECT crc32(arrow_cast('Spark', 'BinaryView')); ---- 1557323817 -# Upstream arrow-rs issue: https://github.com/apache/arrow-rs/issues/8841 -# This should succeed after we receive the fix -query error Arrow error: Compute error: Internal Error: Cannot cast BinaryView to BinaryArray of expected type +query I select crc32(arrow_cast(null, 'Dictionary(Int32, Utf8)')) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt index 7690a38773b04..07f70947fe926 100644 --- a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -75,3 +75,58 @@ SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('ba 967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91 8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52 NULL + +# All string types +query T +SELECT sha2(arrow_cast('foo', 'Utf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db 
+2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeUtf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'Utf8View'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +# All binary types +query T +SELECT sha2(arrow_cast('foo', 'Binary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'LargeBinary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + +query T +SELECT sha2(arrow_cast('foo', 'BinaryView'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length); +---- +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae + + +# Null cases +query T +select sha2(null, 0); +---- +NULL + +query T +select sha2('a', null); +---- +NULL + +query T +select sha2('a', null::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt new file mode 100644 index 0000000000000..30d1672aef0ae --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for Spark-compatible str_to_map function +# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map +# +# Test cases derived from Spark test("StringToMap"): +# https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala#L525-L618 + +# s0: Basic test with default delimiters +query ? +SELECT str_to_map('a:1,b:2,c:3'); +---- +{a: 1, b: 2, c: 3} + +# s1: Preserve spaces in values +query ? +SELECT str_to_map('a: ,b:2'); +---- +{a: , b: 2} + +# s2: Custom key-value delimiter '=' +query ? +SELECT str_to_map('a=1,b=2,c=3', ',', '='); +---- +{a: 1, b: 2, c: 3} + +# s3: Empty string returns map with empty key and NULL value +query ? +SELECT str_to_map('', ',', '='); +---- +{: NULL} + +# s4: Custom pair delimiter '_' +query ? +SELECT str_to_map('a:1_b:2_c:3', '_', ':'); +---- +{a: 1, b: 2, c: 3} + +# s5: Single key without value returns NULL value +query ? +SELECT str_to_map('a'); +---- +{a: NULL} + +# s6: Custom delimiters '&' and '=' +query ? 
+SELECT str_to_map('a=1&b=2&c=3', '&', '='); +---- +{a: 1, b: 2, c: 3} + +# Duplicate keys: EXCEPTION policy (Spark 3.0+ default) +# TODO: Add LAST_WIN policy tests when spark.sql.mapKeyDedupPolicy config is supported +statement error Duplicate map key +SELECT str_to_map('a:1,b:2,a:3'); + + +# Additional tests (DataFusion-specific) + +# NULL input returns NULL +query ? +SELECT str_to_map(NULL, ',', ':'); +---- +NULL + +# Explicit 3-arg form +query ? +SELECT str_to_map('a:1,b:2,c:3', ',', ':'); +---- +{a: 1, b: 2, c: 3} + +# Missing key-value delimiter results in NULL value +query ? +SELECT str_to_map('a,b:2', ',', ':'); +---- +{a: NULL, b: 2} + +# Multi-row test +query ? +SELECT str_to_map(col) FROM (VALUES ('a:1,b:2'), ('x:9'), (NULL)) AS t(col); +---- +{a: 1, b: 2} +{x: 9} +NULL + +# Multi-row with custom delimiter +query ? +SELECT str_to_map(col, ',', '=') FROM (VALUES ('a=1,b=2'), ('x=9'), (NULL)) AS t(col); +---- +{a: 1, b: 2} +{x: 9} +NULL + +# Per-row delimiters: each row can have different delimiters +query ? 
+SELECT str_to_map(col1, col2, col3) FROM (VALUES ('a=1,b=2', ',', '='), ('x#9', ',', '#'), (NULL, ',', '=')) AS t(col1, col2, col3); +---- +{a: 1, b: 2} +{x: 9} +NULL \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index 19ca902ea3de6..94092caab9854 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -24,71 +24,187 @@ ## Original Query: SELECT abs(-1); ## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} -# abs: signed int and NULL +# Scalar input +## Scalar input: signed int and NULL query IIIIR SELECT abs(-127::TINYINT), abs(-32767::SMALLINT), abs(-2147483647::INT), abs(-9223372036854775807::BIGINT), abs(NULL); ---- 127 32767 2147483647 9223372036854775807 NULL - -# See https://github.com/apache/datafusion/issues/18794 for operator precedence -# abs: signed int minimal values +## Scalar input: signed int minimal values +## See https://github.com/apache/datafusion/issues/18794 for operator precedence query IIII -select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT) +select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT); ---- -128 -32768 -2147483648 -9223372036854775808 -# abs: floats, NULL, NaN, -0, infinity, -infinity +## Scalar input: Spark ANSI mode, signed int minimal values +statement ok +set datafusion.execution.enable_ansi_mode = true; + +query error DataFusion error: Arrow error: Compute error: Int8 overflow on abs\(\-128\) +select abs((-128)::TINYINT); + +query error DataFusion error: Arrow error: Compute error: Int16 overflow on abs\(\-32768\) +select abs((-32768)::SMALLINT); + +query error DataFusion error: Arrow error: Compute error: Int32 overflow on abs\(\-2147483648\) +select abs((-2147483648)::INT); + +query error 
DataFusion error: Arrow error: Compute error: Int64 overflow on abs\(\-9223372036854775808\) +select abs((-9223372036854775808)::BIGINT); + +statement ok +set datafusion.execution.enable_ansi_mode = false; + +## Scalar input: float, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR -SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT) +SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT); ---- 1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity -# abs: doubles, NULL, NaN, -0, infinity, -infinity +## Scalar input: double, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR -SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE) +SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE); ---- 1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity -# abs: decimal128 and decimal256 -statement ok -CREATE TABLE test_nullable_decimal( - c1 DECIMAL(10, 2), /* Decimal128 */ - c2 DECIMAL(38, 10), /* Decimal128 with max precision */ - c3 DECIMAL(40, 2), /* Decimal256 */ - c4 DECIMAL(76, 10) /* Decimal256 with max precision */ - ) AS VALUES - (0, 0, 0, 0), - (NULL, NULL, NULL, NULL); +## Scalar input: decimal128 +query RRR +SELECT abs(('-99999999.99')::DECIMAL(10, 2)), 
abs(0::DECIMAL(10, 2)), abs(NULL::DECIMAL(10, 2)); +---- +99999999.99 0 NULL + +query RRR +SELECT abs(('-9999999999999999999999999999.9999999999')::DECIMAL(38, 10)), abs(0::DECIMAL(38, 10)), abs(NULL::DECIMAL(38, 10)); +---- +9999999999999999999999999999.9999999999 0 NULL + +## Scalar input: decimal256 +query RRR +SELECT abs(('-99999999999999999999999999999999999999.99')::DECIMAL(40, 2)), abs(0::DECIMAL(40, 2)), abs(NULL::DECIMAL(40, 2)); +---- +99999999999999999999999999999999999999.99 0 NULL + +query RRR +SELECT abs(('-999999999999999999999999999999999999999999999999999999999999999999.9999999999')::DECIMAL(76, 10)), abs(0::DECIMAL(76, 10)), abs(NULL::DECIMAL(76, 10)); +---- +999999999999999999999999999999999999999999999999999999999999999999.9999999999 0 NULL + + +# Array input +## Array input: signed int, signed int minimal values and NULL +query I +SELECT abs(a) FROM (VALUES (-127::TINYINT), ((-128)::TINYINT), (NULL)) AS t(a); +---- +127 +-128 +NULL + +query I +select abs(a) FROM (VALUES (-32767::SMALLINT), ((-32768)::SMALLINT), (NULL)) AS t(a); +---- +32767 +-32768 +NULL + +query I +select abs(a) FROM (VALUES (-2147483647::INT), ((-2147483648)::INT), (NULL)) AS t(a); +---- +2147483647 +-2147483648 +NULL query I -INSERT into test_nullable_decimal values - ( - -99999999.99, - '-9999999999999999999999999999.9999999999', - '-99999999999999999999999999999999999999.99', - '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' - ), - ( - 99999999.99, - '9999999999999999999999999999.9999999999', - '99999999999999999999999999999999999999.99', - '999999999999999999999999999999999999999999999999999999999999999999.9999999999' - ) ----- -2 - -query RRRR rowsort -SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal ----- -0 0 0 0 -99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 -99999999.99 
9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 -NULL NULL NULL NULL +select abs(a) FROM (VALUES (-9223372036854775807::BIGINT), ((-9223372036854775808)::BIGINT), (NULL)) AS t(a); +---- +9223372036854775807 +-9223372036854775808 +NULL + +## Array Input: Spark ANSI mode, signed int minimal values +statement ok +set datafusion.execution.enable_ansi_mode = true; + +query error DataFusion error: Arrow error: Compute error: Int8Array overflow on abs\(\-128\) +SELECT abs(a) FROM (VALUES (-127::TINYINT), ((-128)::TINYINT)) AS t(a); + +query error DataFusion error: Arrow error: Compute error: Int16Array overflow on abs\(\-32768\) +select abs(a) FROM (VALUES (-32767::SMALLINT), ((-32768)::SMALLINT)) AS t(a); +query error DataFusion error: Arrow error: Compute error: Int32Array overflow on abs\(\-2147483648\) +select abs(a) FROM (VALUES (-2147483647::INT), ((-2147483648)::INT)) AS t(a); + +query error DataFusion error: Arrow error: Compute error: Int64Array overflow on abs\(\-9223372036854775808\) +select abs(a) FROM (VALUES (-9223372036854775807::BIGINT), ((-9223372036854775808)::BIGINT)) AS t(a); statement ok -drop table test_nullable_decimal +set datafusion.execution.enable_ansi_mode = false; + +## Array input: float, NULL, NaN, -0, infinity, -infinity +query R +SELECT abs(a) FROM (VALUES (-1.0::FLOAT), (0.::FLOAT), (-0.::FLOAT), (-0::FLOAT), (NULL::FLOAT), ('NaN'::FLOAT), ('inf'::FLOAT), ('+inf'::FLOAT), ('-inf'::FLOAT), ('infinity'::FLOAT), ('+infinity'::FLOAT), ('-infinity'::FLOAT)) AS t(a); +---- +1 +0 +0 +0 +NULL +NaN +Infinity +Infinity +Infinity +Infinity +Infinity +Infinity + + +## Array input: double, NULL, NaN, -0, infinity, -infinity +query R +SELECT abs(a) FROM (VALUES (-1.0::DOUBLE), (0.::DOUBLE), (-0.::DOUBLE), (-0::DOUBLE), (NULL::DOUBLE), ('NaN'::DOUBLE), ('inf'::DOUBLE), ('+inf'::DOUBLE), ('-inf'::DOUBLE), ('infinity'::DOUBLE), 
('+infinity'::DOUBLE), ('-infinity'::DOUBLE)) AS t(a); +---- +1 +0 +0 +0 +NULL +NaN +Infinity +Infinity +Infinity +Infinity +Infinity +Infinity + +## Array input: decimal128 +query R +SELECT abs(a) FROM (VALUES (('-99999999.99')::DECIMAL(10, 2)), (0::DECIMAL(10, 2)), (NULL::DECIMAL(10, 2))) AS t(a); +---- +99999999.99 +0 +NULL + +query R +SELECT abs(a) FROM (VALUES (('-9999999999999999999999999999.9999999999')::DECIMAL(38, 10)), (0::DECIMAL(38, 10)), (NULL::DECIMAL(38, 10))) AS t(a); +---- +9999999999999999999999999999.9999999999 +0 +NULL + +## Array input: decimal256 +query R +SELECT abs(a) FROM (VALUES (('-99999999999999999999999999999999999999.99')::DECIMAL(40, 2)), (0::DECIMAL(40, 2)), (NULL::DECIMAL(40, 2))) AS t(a); +---- +99999999999999999999999999999999999999.99 +0 +NULL + +query R +SELECT abs(a) FROM (VALUES (('-999999999999999999999999999999999999999999999999999999999999999999.9999999999')::DECIMAL(76, 10)), (0::DECIMAL(76, 10)), (NULL::DECIMAL(76, 10))) AS t(a); +---- +999999999999999999999999999999999999999999999999999999999999999999.9999999999 +0 +NULL ## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 05c9fb3f31b28..17e9ff432890d 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -63,3 +63,23 @@ query T SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; ---- 74657374 + +statement ok +CREATE TABLE t_dict_binary AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Binary)') as dict_col +FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); + +query T +SELECT hex(dict_col) FROM t_dict_binary; +---- +666F6F +626172 +666F6F 
+NULL +62617A +626172 + +query T +SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1; +---- +Dictionary(Int32, Utf8) diff --git a/datafusion/sqllogictest/test_files/spark/math/negative.slt b/datafusion/sqllogictest/test_files/spark/math/negative.slt index aa8e558e9895e..40bfaf791fe81 100644 --- a/datafusion/sqllogictest/test_files/spark/math/negative.slt +++ b/datafusion/sqllogictest/test_files/spark/math/negative.slt @@ -23,5 +23,309 @@ ## Original Query: SELECT negative(1); ## PySpark 3.5.5 Result: {'negative(1)': -1, 'typeof(negative(1))': 'int', 'typeof(1)': 'int'} -#query -#SELECT negative(1::int); + +# Test negative with integer +query I +SELECT negative(1::int); +---- +-1 + +# Test negative with positive integer +query I +SELECT negative(42::int); +---- +-42 + +# Test negative with negative integer +query I +SELECT negative(-10::int); +---- +10 + +# Test negative with zero +query I +SELECT negative(0::int); +---- +0 + +# Test negative with bigint +query I +SELECT negative(9223372036854775807::bigint); +---- +-9223372036854775807 + +# Test negative with negative bigint +query I +SELECT negative(-100::bigint); +---- +100 + +# Test negative with smallint +query I +SELECT negative(32767::smallint); +---- +-32767 + +# Test negative with float +query R +SELECT negative(3.14::float); +---- +-3.14 + +# Test negative with negative float +query R +SELECT negative(-2.5::float); +---- +2.5 + +# Test negative with double +query R +SELECT negative(3.14159265358979::double); +---- +-3.14159265358979 + +# Test negative with negative double +query R +SELECT negative(-1.5::double); +---- +1.5 + +# Test negative with decimal +query R +SELECT negative(123.456::decimal(10,3)); +---- +-123.456 + +# Test negative with negative decimal +query R +SELECT negative(-99.99::decimal(10,2)); +---- +99.99 + +# Test negative with NULL +query I +SELECT negative(NULL::int); +---- +NULL + +# Test negative with column values +statement ok +CREATE TABLE test_negative (id int, value 
int) AS VALUES (1, 10), (2, -20), (3, 0), (4, NULL); + +query II rowsort +SELECT id, negative(value) FROM test_negative; +---- +1 -10 +2 20 +3 0 +4 NULL + +statement ok +DROP TABLE test_negative; + +# Test negative in expressions +query I +SELECT negative(5) + 3; +---- +-2 + +# Test nested negative +query I +SELECT negative(negative(7)); +---- +7 + +# Test negative with large numbers +query R +SELECT negative(1234567890.123456::double); +---- +-1234567890.123456 + +# Test wrap-around: negative of minimum int (should wrap to same value) +# Using table to avoid constant folding overflow during optimization +statement ok +CREATE TABLE min_values_int AS VALUES (-2147483648); + +query I +SELECT negative(column1::int) FROM min_values_int; +---- +-2147483648 + +statement ok +DROP TABLE min_values_int; + +# Test wrap-around: negative of minimum bigint (should wrap to same value) +statement ok +CREATE TABLE min_values_bigint AS VALUES (-9223372036854775808); + +query I +SELECT negative(column1::bigint) FROM min_values_bigint; +---- +-9223372036854775808 + +statement ok +DROP TABLE min_values_bigint; + +# Test wrap-around: negative of minimum smallint (should wrap to same value) +statement ok +CREATE TABLE min_values_smallint AS VALUES (-32768); + +query I +SELECT negative(column1::smallint) FROM min_values_smallint; +---- +-32768 + +statement ok +DROP TABLE min_values_smallint; + +# Test wrap-around: negative of minimum tinyint (should wrap to same value) +statement ok +CREATE TABLE min_values_tinyint AS VALUES (-128); + +query I +SELECT negative(column1::tinyint) FROM min_values_tinyint; +---- +-128 + +statement ok +DROP TABLE min_values_tinyint; + +# Test overflow: negative of positive infinity (float) +query R +SELECT negative('Infinity'::float); +---- +-Infinity + +# Test overflow: negative of negative infinity (float) +query R +SELECT negative('-Infinity'::float); +---- +Infinity + +# Test overflow: negative of positive infinity (double) +query R +SELECT 
negative('Infinity'::double); +---- +-Infinity + +# Test overflow: negative of negative infinity (double) +query R +SELECT negative('-Infinity'::double); +---- +Infinity + +# Test overflow: negative of NaN (float) +query R +SELECT negative('NaN'::float); +---- +NaN + +# Test overflow: negative of NaN (double) +query R +SELECT negative('NaN'::double); +---- +NaN + +# Test overflow: negative of maximum float value +query R +SELECT negative(3.4028235e38::float); +---- +-340282350000000000000000000000000000000 + +# Test overflow: negative of minimum float value +query R +SELECT negative(-3.4028235e38::float); +---- +340282350000000000000000000000000000000 + +# Test overflow: negative of maximum double value +query R +SELECT negative(1.7976931348623157e308::double); +---- +-179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 + +# Test overflow: negative of minimum double value +query R +SELECT negative(-1.7976931348623157e308::double); +---- +179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 + +# Test negative with CalendarIntervalType (IntervalMonthDayNano) +# Spark make_interval creates CalendarInterval +query ? +SELECT negative(make_interval(1, 2, 3, 4, 5, 6, 7.5)); +---- +-14 mons -25 days -5 hours -6 mins -7.500000000 secs + +# Test negative with negative CalendarIntervalType +query ? 
+SELECT negative(make_interval(-2, -5, -1, -10, -3, -30, -15.25)); +---- +29 mons 17 days 3 hours 30 mins 15.250000000 secs + +# Test negative with CalendarInterval from table +statement ok +CREATE TABLE interval_test AS VALUES + (make_interval(1, 2, 0, 5, 0, 0, 0.0)), + (make_interval(-3, -1, 0, -2, 0, 0, 0.0)); + +query ? rowsort +SELECT negative(column1) FROM interval_test; +---- +-14 mons -5 days +37 mons 2 days + +statement ok +DROP TABLE interval_test; + +## ANSI mode tests: overflow detection +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# Test ANSI mode: negative of minimum values should error (overflow) +query error DataFusion error: Execution error: Int8 overflow on negative\(\-128\) +SELECT negative((-128)::tinyint); + +query error DataFusion error: Execution error: Int16 overflow on negative\(\-32768\) +SELECT negative((-32768)::smallint); + +query error DataFusion error: Execution error: Int32 overflow on negative\(\-2147483648\) +SELECT negative((-2147483648)::int); + +query error DataFusion error: Execution error: Int64 overflow on negative\(\-9223372036854775808\) +SELECT negative((-9223372036854775808)::bigint); + +# Test ANSI mode: negative of (MIN+1) should succeed (boundary test) +query I +SELECT negative((-127)::tinyint); +---- +127 + +query I +SELECT negative((-32767)::smallint); +---- +32767 + +query I +SELECT negative((-2147483647)::int); +---- +2147483647 + +query I +SELECT negative((-9223372036854775807)::bigint); +---- +9223372036854775807 + +# Test ANSI mode: array with MIN value should error +statement ok +CREATE TABLE min_values_ansi AS VALUES (-2147483648); + +query error DataFusion error: Execution error: Int32 overflow on negative\(\-2147483648\) +SELECT negative(column1::int) FROM min_values_ansi; + +statement ok +DROP TABLE min_values_ansi; + +# Reset ANSI mode to false +statement ok +set datafusion.execution.enable_ansi_mode = false; diff --git a/datafusion/sqllogictest/test_files/spark/math/unhex.slt 
b/datafusion/sqllogictest/test_files/spark/math/unhex.slt new file mode 100644 index 0000000000000..051d8826c8a6c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/unhex.slt @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Basic hex string +query ? +SELECT unhex('537061726B2053514C'); +---- +537061726b2053514c + +query T +SELECT arrow_cast(unhex('537061726B2053514C'), 'Utf8'); +---- +Spark SQL + +# Lowercase hex +query ? +SELECT unhex('616263'); +---- +616263 + +query T +SELECT arrow_cast(unhex('616263'), 'Utf8'); +---- +abc + +# Odd length hex (left pad with 0) +query ? +SELECT unhex(a) FROM VALUES ('1A2B3'), ('1'), ('ABC'), ('123') AS t(a); +---- +01a2b3 +01 +0abc +0123 + +# Null input +query ? +SELECT unhex(NULL); +---- +NULL + +# Invalid hex characters +query ? +SELECT unhex('GGHH'); +---- +NULL + +# Empty hex string +query T +SELECT arrow_cast(unhex(''), 'Utf8'); +---- +(empty) + +# Array with mixed case +query ? 
+SELECT unhex(a) FROM VALUES ('4a4B4c'), ('F'), ('A'), ('AbCdEf'), ('123abc'), ('41 42'), ('00'), ('FF') AS t(a); +---- +4a4b4c +0f +0a +abcdef +123abc +NULL +00 +ff + +# LargeUtf8 type +statement ok +CREATE TABLE t_large_utf8 AS VALUES (arrow_cast('414243', 'LargeUtf8')), (NULL); + +query ? +SELECT unhex(column1) FROM t_large_utf8; +---- +414243 +NULL + +# Utf8View type +statement ok +CREATE TABLE t_utf8view AS VALUES (arrow_cast('414243', 'Utf8View')), (NULL); + +query ? +SELECT unhex(column1) FROM t_utf8view; +---- +414243 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/string/base64.slt b/datafusion/sqllogictest/test_files/spark/string/base64.slt index 66edbe8442158..03b488de0ee9a 100644 --- a/datafusion/sqllogictest/test_files/spark/string/base64.slt +++ b/datafusion/sqllogictest/test_files/spark/string/base64.slt @@ -15,18 +15,101 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT base64('Spark SQL'); -## PySpark 3.5.5 Result: {'base64(Spark SQL)': 'U3BhcmsgU1FM', 'typeof(base64(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} -#query -#SELECT base64('Spark SQL'::string); - -## Original Query: SELECT base64(x'537061726b2053514c'); -## PySpark 3.5.5 Result: {"base64(X'537061726B2053514C')": 'U3BhcmsgU1FM', "typeof(base64(X'537061726B2053514C'))": 'string', "typeof(X'537061726B2053514C')": 'binary'} -#query -#SELECT base64(X'537061726B2053514C'::binary); +query T +SELECT base64('Spark SQL'::string); +---- +U3BhcmsgU1FM + +query T +SELECT base64('Spark SQ'::string); +---- +U3BhcmsgU1E= + +query T +SELECT base64('Spark S'::string); +---- +U3BhcmsgUw== + +query T +SELECT base64('Spark SQL'::bytea); +---- +U3BhcmsgU1FM + +query T +SELECT base64(NULL::string); +---- +NULL + +query T +SELECT base64(NULL::bytea); +---- +NULL + +query T +SELECT base64(column1) +FROM VALUES +('Spark SQL'::bytea), +('Spark SQ'::bytea), +('Spark S'::bytea), +(NULL::bytea); +---- +U3BhcmsgU1FM +U3BhcmsgU1E= +U3BhcmsgUw== +NULL + +query error Function 'base64' requires TypeSignatureClass::Binary, but received Int32 \(DataType: Int32\) +SELECT base64(12::integer); + + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1FM'::string), 'Utf8'); +---- +Spark SQL + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1E='::string), 'Utf8'); +---- +Spark SQ + +query T +SELECT arrow_cast(unbase64('U3BhcmsgUw=='::string), 'Utf8'); +---- +Spark S + +query T +SELECT arrow_cast(unbase64('U3BhcmsgU1FM'::bytea), 'Utf8'); +---- +Spark SQL + +query ? +SELECT unbase64(NULL::string); +---- +NULL + +query ? 
+SELECT unbase64(NULL::bytea); +---- +NULL + +query T +SELECT arrow_cast(unbase64(column1), 'Utf8') +FROM VALUES +('U3BhcmsgU1FM'::string), +('U3BhcmsgU1E='::string), +('U3BhcmsgUw=='::string), +(NULL::string); +---- +Spark SQL +Spark SQ +Spark S +NULL + +query error Failed to decode value using base64 +SELECT unbase64('123'::string); + +query error Failed to decode value using base64 +SELECT unbase64('123'::bytea); + +query error Function 'unbase64' requires TypeSignatureClass::Binary, but received Int32 \(DataType: Int32\) +SELECT unbase64(12::integer); diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt index 258cb829d7d4b..97e7b57f7d06e 100644 --- a/datafusion/sqllogictest/test_files/spark/string/concat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -20,6 +20,12 @@ SELECT concat('Spark', 'SQL'); ---- SparkSQL +# Test two Utf8View inputs: value and return type +query TT +SELECT concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View')), arrow_typeof(concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View'))); +---- +SparkSQL Utf8View + query T SELECT concat('Spark', 'SQL', NULL); ---- @@ -46,3 +52,21 @@ SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, ---- abc NULL + +# Test mixed types: Utf8View + Utf8 +query TT +SELECT concat(arrow_cast('hello', 'Utf8View'), ' world'), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), ' world')); +---- +hello world Utf8View + +# Test Utf8 + LargeUtf8 => return type LargeUtf8 +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'))); +---- +ab LargeUtf8 + +# Test all three types mixed together +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View'))); +---- +abc Utf8View diff --git 
a/datafusion/sqllogictest/test_files/spark/string/format_string.slt b/datafusion/sqllogictest/test_files/spark/string/format_string.slt index 048863ebfbedb..8ba3cfc951cdc 100644 --- a/datafusion/sqllogictest/test_files/spark/string/format_string.slt +++ b/datafusion/sqllogictest/test_files/spark/string/format_string.slt @@ -931,13 +931,13 @@ Char: NULL ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Hour: %tH', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Hour: %tH', arrow_cast(NULL, 'Timestamp(ns)')); ---- Hour: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ns)')); ---- Month: null @@ -967,25 +967,25 @@ Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Second, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(s)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Millisecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ms)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Microsecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(µs)')); ---- Month: null ## NULL with timestamp format using arrow_cast query T -SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(ns)')); ---- Month: null @@ -1051,7 +1051,7 @@ Value: null ## NULL Timestamp with string format using arrow_cast query T -SELECT format_string('Value: %s', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +SELECT format_string('Value: %s', arrow_cast(NULL, 
'Timestamp(ns)')); ---- Value: null @@ -1717,49 +1717,49 @@ String: 52245000000000 ## TimestampSecond with time formats query T -SELECT format_string('Year: %tY', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('Year: %tY', arrow_cast(1703512245, 'Timestamp(s)')); ---- Year: 2023 query T -SELECT format_string('Month: %tm', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('Month: %tm', arrow_cast(1703512245, 'Timestamp(s)')); ---- Month: 12 query T -SELECT format_string('String: %s', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245, 'Timestamp(s)')); ---- String: 1703512245 query T -SELECT format_string('String: %S', arrow_cast(1703512245, 'Timestamp(Second, None)')); +SELECT format_string('String: %S', arrow_cast(1703512245, 'Timestamp(s)')); ---- String: 1703512245 ## TimestampMillisecond with time formats query T -SELECT format_string('ISO Date: %tF', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +SELECT format_string('ISO Date: %tF', arrow_cast(1703512245000, 'Timestamp(ms)')); ---- ISO Date: 2023-12-25 query T -SELECT format_string('String: %s', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245000, 'Timestamp(ms)')); ---- String: 1703512245000 ## TimestampMicrosecond with time formats query T -SELECT format_string('Date: %tD', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +SELECT format_string('Date: %tD', arrow_cast(1703512245000000, 'Timestamp(µs)')); ---- Date: 12/25/23 query T -SELECT format_string('String: %s', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +SELECT format_string('String: %s', arrow_cast(1703512245000000, 'Timestamp(µs)')); ---- String: 1703512245000000 query T -SELECT format_string('String: %s', arrow_cast('2020-01-02 01:01:11.1234567890Z', 'Timestamp(Nanosecond, None)')); +SELECT format_string('String: %s', 
arrow_cast('2020-01-02 01:01:11.1234567890Z', 'Timestamp(ns)')); ---- String: 1577926871123456789 diff --git a/datafusion/sqllogictest/test_files/spark/string/substr.slt b/datafusion/sqllogictest/test_files/spark/string/substr.slt deleted file mode 100644 index 0942bdd86a4ef..0000000000000 --- a/datafusion/sqllogictest/test_files/spark/string/substr.slt +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT substr('Spark SQL', -3); -## PySpark 3.5.5 Result: {'substr(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substr(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} -#query -#SELECT substr('Spark SQL'::string, -3::int); - -## Original Query: SELECT substr('Spark SQL', 5); -## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substr(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} -#query -#SELECT substr('Spark SQL'::string, 5::int); - -## Original Query: SELECT substr('Spark SQL', 5, 1); -## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 1)': 'k', 'typeof(substr(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} -#query -#SELECT substr('Spark SQL'::string, 5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/substring.slt b/datafusion/sqllogictest/test_files/spark/string/substring.slt index 847ce4b6d4739..5bf2fdf2fb954 100644 --- a/datafusion/sqllogictest/test_files/spark/string/substring.slt +++ b/datafusion/sqllogictest/test_files/spark/string/substring.slt @@ -15,23 +15,189 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. 
-# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT substring('Spark SQL', -3); -## PySpark 3.5.5 Result: {'substring(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substring(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} -#query -#SELECT substring('Spark SQL'::string, -3::int); - -## Original Query: SELECT substring('Spark SQL', 5); -## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substring(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} -#query -#SELECT substring('Spark SQL'::string, 5::int); - -## Original Query: SELECT substring('Spark SQL', 5, 1); -## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 1)': 'k', 'typeof(substring(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} -#query -#SELECT substring('Spark SQL'::string, 5::int, 1::int); + +query T +SELECT substring('Spark SQL'::string, 0::int); +---- +Spark SQL + +query T +SELECT substring('Spark SQL'::string, 5::int); +---- +k SQL + +query T +SELECT substring('Spark SQL'::string, 3::int, 1::int); +---- +a + +# Test negative start +query T +SELECT substring('Spark SQL'::string, -3::int); +---- +SQL + +query T +SELECT substring('Spark SQL'::string, -3::int, 2::int); +---- +SQ + +# Test length exceeding string length +query T +SELECT substring('Spark SQL'::string, 2::int, 700::int); +---- +park SQL + +# Test start position beyond string length +query T +SELECT substring('Spark SQL'::string, 30::int); +---- +(empty) + +query T +SELECT substring('Spark SQL'::string, -30::int); +---- +Spark SQL + +# Test negative length +query T +SELECT substring('Spark SQL'::string, 3::int, -1::int); +---- +(empty) + +query T +SELECT substring('Spark SQL'::string, 3::int, 0::int); +---- +(empty) + +# Test unicode strings +query T +SELECT substring('joséésoj'::string, 5::int); +---- +ésoj + +query 
T +SELECT substring('joséésoj'::string, 5::int, 2::int); +---- +és + +# NULL handling +query T +SELECT substring('Spark SQL'::string, NULL::int); +---- +NULL + +query T +SELECT substring(NULL::string, 5::int); +---- +NULL + +query T +SELECT substring(NULL::string, 3::int, 1::int); +---- +NULL + +query T +SELECT substring('Spark SQL'::string, NULL::int, 1::int); +---- +NULL + +query T +SELECT substring('Spark SQL'::string, 3::int, NULL::int); +---- +NULL + +query T +SELECT substring(column1, column2) +FROM VALUES +('Spark SQL'::string, 0::int), +('Spark SQL'::string, 5::int), +('Spark SQL'::string, -3::int), +('Spark SQL'::string, 500::int), +('Spark SQL'::string, -300::int), +(NULL::string, 5::int), +('Spark SQL'::string, NULL::int); +---- +Spark SQL +k SQL +SQL +(empty) +Spark SQL +NULL +NULL + +query T +SELECT substring(column1, column2, column3) +FROM VALUES +('Spark SQL'::string, -3::int, 2::int), +('Spark SQL'::string, 3::int, 1::int), +('Spark SQL'::string, 3::int, 700::int), +('Spark SQL'::string, 3::int, -1::int), +('Spark SQL'::string, 3::int, 0::int), +('Spark SQL'::string, 300::int, 3::int), +('Spark SQL'::string, -300::int, 3::int), +(NULL::string, 3::int, 1::int), +('Spark SQL'::string, NULL::int, 1::int), +('Spark SQL'::string, 3::int, NULL::int); +---- +SQ +a +ark SQL +(empty) +(empty) +(empty) +Spa +NULL +NULL +NULL + +# alias substr + +query T +SELECT substr('Spark SQL'::string, 0::int); +---- +Spark SQL + +query T +SELECT substr(column1, column2) +FROM VALUES +('Spark SQL'::string, 0::int), +('Spark SQL'::string, 5::int), +('Spark SQL'::string, -3::int), +('Spark SQL'::string, 500::int), +('Spark SQL'::string, -300::int), +(NULL::string, 5::int), +('Spark SQL'::string, NULL::int); +---- +Spark SQL +k SQL +SQL +(empty) +Spark SQL +NULL +NULL + +query T +SELECT substr(column1, column2, column3) +FROM VALUES +('Spark SQL'::string, -3::int, 2::int), +('Spark SQL'::string, 3::int, 1::int), +('Spark SQL'::string, 3::int, 700::int), +('Spark 
SQL'::string, 3::int, -1::int), +('Spark SQL'::string, 3::int, 0::int), +('Spark SQL'::string, 300::int, 3::int), +('Spark SQL'::string, -300::int, 3::int), +(NULL::string, 3::int, 1::int), +('Spark SQL'::string, NULL::int, 1::int), +('Spark SQL'::string, 3::int, NULL::int); +---- +SQ +a +ark SQL +(empty) +(empty) +(empty) +Spa +NULL +NULL +NULL diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index a182ba8cde111..2884c3518610d 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -993,25 +993,27 @@ NULL NULL NULL NULL # Test FIND_IN_SET # -------------------------------------- -query IIII +query IIIIII SELECT FIND_IN_SET(ascii_1, 'a,b,c,d'), FIND_IN_SET(ascii_1, 'Andrew,Xiangpeng,Raphael'), FIND_IN_SET(unicode_1, 'a,b,c,d'), - FIND_IN_SET(unicode_1, 'datafusion📊🔥,datafusion数据融合,datafusionДатаФусион') + FIND_IN_SET(unicode_1, 'datafusion📊🔥,datafusion数据融合,datafusionДатаФусион'), + FIND_IN_SET(NULL, unicode_1), + FIND_IN_SET(unicode_1, NULL) FROM test_basic_operator; ---- -0 1 0 1 -0 2 0 2 -0 3 0 3 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -0 0 0 0 -NULL NULL NULL NULL -NULL NULL NULL NULL +0 1 0 1 NULL NULL +0 2 0 2 NULL NULL +0 3 0 3 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +0 0 0 0 NULL NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL # -------------------------------------- # Test || operator diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index d985af1104da3..e20815a58c765 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -38,9 +38,9 @@ CREATE TABLE struct_values ( s1 struct, s2 struct ) AS VALUES - (struct(1), struct(1, 'string1')), - (struct(2), struct(2, 'string2')), - 
(struct(3), struct(3, 'string3')) + (struct(1), struct(1 AS a, 'string1' AS b)), + (struct(2), struct(2 AS a, 'string2' AS b)), + (struct(3), struct(3 AS a, 'string3' AS b)) ; query ?? @@ -397,7 +397,8 @@ drop view complex_view; # struct with different keys r1 and r2 is not valid statement ok -create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); +create table t(a struct, b struct) as values + (struct('red' AS r1, 1 AS c), struct('blue' AS r2, 2.3 AS c)); # Expect same keys for struct type but got mismatched pair r1,c and r2,c query error @@ -408,7 +409,8 @@ drop table t; # struct with the same key statement ok -create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); +create table t(a struct, b struct) as values + (struct('red' AS r, 1 AS c), struct('blue' AS r, 2.3 AS c)); query T select arrow_typeof([a, b]) from t; @@ -442,9 +444,9 @@ CREATE TABLE struct_values ( s1 struct(a int, b varchar), s2 struct(a int, b varchar) ) AS VALUES - (row(1, 'red'), row(1, 'string1')), - (row(2, 'blue'), row(2, 'string2')), - (row(3, 'green'), row(3, 'string3')) + ({a: 1, b: 'red'}, {a: 1, b: 'string1'}), + ({a: 2, b: 'blue'}, {a: 2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 3, b: 'string3'}) ; statement ok @@ -452,8 +454,8 @@ drop table struct_values; statement ok create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values ( - row('red', 2), - row('blue', 2.3) + {r: 'red', b: 2}, + {r: 'blue', b: 2.3} ); query ?? 
@@ -492,9 +494,6 @@ Struct("r": Utf8, "c": Float64) statement ok drop table t; -query error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'a' to value of Float64 type -create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); - ################################## ## Test Coalesce with Struct ################################## @@ -504,9 +503,9 @@ CREATE TABLE t ( s1 struct(a int, b varchar), s2 struct(a float, b varchar) ) AS VALUES - (row(1, 'red'), row(1.1, 'string1')), - (row(2, 'blue'), row(2.2, 'string2')), - (row(3, 'green'), row(33.2, 'string3')) + ({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}), + ({a: 2, b: 'blue'}, {a: 2.2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 33.2, b: 'string3'}) ; query ? @@ -531,9 +530,9 @@ CREATE TABLE t ( s1 struct(a int, b varchar), s2 struct(a float, b varchar) ) AS VALUES - (row(1, 'red'), row(1.1, 'string1')), - (null, row(2.2, 'string2')), - (row(3, 'green'), row(33.2, 'string3')) + ({a: 1, b: 'red'}, {a: 1.1, b: 'string1'}), + (null, {a: 2.2, b: 'string2'}), + ({a: 3, b: 'green'}, {a: 33.2, b: 'string3'}) ; query ? 
@@ -553,16 +552,12 @@ Struct("a": Float32, "b": Utf8View) statement ok drop table t; -# row() with incorrect order +# row() with incorrect order - row() is positional, not name-based statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values - (row('red', 1), row(2.3, 'blue')), - (row('purple', 1), row('green', 2.3)); + ({r: 'red', c: 1}, {r: 2.3, c: 'blue'}), + ({r: 'purple', c: 1}, {r: 'green', c: 2.3}); -# out of order struct literal -# TODO: This query should not fail -statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Arrow error: Cast error: Cannot cast string 'b' to value of Int32 type -create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); ################################## ## Test Array of Struct @@ -573,12 +568,9 @@ select [{r: 'a', c: 1}, {r: 'b', c: 2}]; ---- [{r: a, c: 1}, {r: b, c: 2}] -# Can't create a list of struct with different field types -query error -select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; statement ok -create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3)); +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values ({r: 'a', c: 1}, {r: 'b', c: 2.3}); query T select arrow_typeof([a, b]) from t; @@ -588,27 +580,17 @@ List(Struct("r": Utf8View, "c": Float32)) statement ok drop table t; -# create table with different struct type is fine -statement ok -create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b')); - -# create array with different struct type is not valid -query error -select arrow_typeof([a, b]) from t; - -statement ok -drop table t; statement ok -create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values 
(row('a', 1, 2.3), row('b', 2.3, 2)); +create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values ({r: 'a', c: 1, g: 2.3}, {r: 'b', c: 2.3, g: 2}); -# type of each column should not coerced but perserve as it is +# type of each column should not coerced but preserve as it is query T select arrow_typeof(a) from t; ---- Struct("r": Utf8View, "c": Int32, "g": Float32) -# type of each column should not coerced but perserve as it is +# type of each column should not coerced but preserve as it is query T select arrow_typeof(b) from t; ---- @@ -622,7 +604,7 @@ drop table t; # This tests accessing struct fields using the subscript notation with string literals statement ok -create table test (struct_field struct(substruct int)) as values (struct(1)); +create table test (struct_field struct(substruct int)) as values ({substruct: 1}); query ?? select * @@ -635,7 +617,7 @@ statement ok DROP TABLE test; statement ok -create table test (struct_field struct(substruct struct(subsubstruct int))) as values (struct(struct(1))); +create table test (struct_field struct(substruct struct(subsubstruct int))) as values ({substruct: {subsubstruct: 1}}); query ?? select * @@ -824,3 +806,864 @@ NULL statement ok drop table nullable_parent_test; + +# Test struct casting with field reordering - string fields +query ? +SELECT CAST({b: 'b_value', a: 'a_value'} AS STRUCT(a VARCHAR, b VARCHAR)); +---- +{a: a_value, b: b_value} + +# Test struct casting with field reordering - integer fields +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a INT, b INT)); +---- +{a: 4, b: 3} + +# Test with type casting AND field reordering +query ? +SELECT CAST({b: 3, a: 4} AS STRUCT(a BIGINT, b INT)); +---- +{a: 4, b: 3} + +# Test casting with explicit field names +query ? +SELECT CAST({a: 1, b: 'x'} AS STRUCT(a INT, b VARCHAR)); +---- +{a: 1, b: x} + +# Test with missing field - should insert nulls +query ? 
+SELECT CAST({a: 1} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: NULL} + +# Test with extra source field - should be ignored +query ? +SELECT CAST({a: 1, b: 2, extra: 3} AS STRUCT(a INT, b INT)); +---- +{a: 1, b: 2} + +# Test no overlap with mismatched field count - should fail because no field names match +statement error DataFusion error: (Plan error|Error during planning|This feature is not implemented): (Cannot cast struct: at least one field name must match between source and target|Cannot cast struct with 3 fields to 2 fields without name overlap|Unsupported CAST from Struct) +SELECT CAST(struct(1, 'x', 'y') AS STRUCT(a INT, b VARCHAR)); + +# Test nested struct with field reordering +query ? +SELECT CAST( + {inner: {y: 2, x: 1}} + AS STRUCT(inner STRUCT(x INT, y INT)) +); +---- +{inner: {x: 1, y: 2}} + +# Test field reordering with table data +statement ok +CREATE TABLE struct_reorder_test ( + data STRUCT(b INT, a VARCHAR) +) AS VALUES + ({b: 100, a: 'first'}), + ({b: 200, a: 'second'}), + ({b: 300, a: 'third'}) +; + +query ? +SELECT CAST(data AS STRUCT(a VARCHAR, b INT)) AS casted_data FROM struct_reorder_test ORDER BY data['b']; +---- +{a: first, b: 100} +{a: second, b: 200} +{a: third, b: 300} + +statement ok +drop table struct_reorder_test; + +# Test casting struct with multiple levels of nesting and reordering +query ? +SELECT CAST( + {level1: {z: 100, y: 'inner', x: 1}} + AS STRUCT(level1 STRUCT(x INT, y VARCHAR, z INT)) +); +---- +{level1: {x: 1, y: inner, z: 100}} + +# Test field reordering with nulls in source +query ? +SELECT CAST( + {b: NULL::INT, a: 42} + AS STRUCT(a INT, b INT) +); +---- +{a: 42, b: NULL} + +# Test casting preserves struct-level nulls +query ? 
+SELECT CAST(NULL::STRUCT(b INT, a INT) AS STRUCT(a INT, b INT)); +---- +NULL + +############################ +# Implicit Coercion Tests with CREATE TABLE AS VALUES +############################ + +# Test implicit coercion with same field order, different types +statement ok +create table t as values({r: 'a', c: 1}), ({r: 'b', c: 2.3}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("r": Utf8, "c": Float64) + +query ? +select * from t order by column1.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +statement ok +drop table t; + +# Test implicit coercion with nullable fields (same order) +statement ok +create table t as values({a: 1, b: 'x'}), ({a: 2, b: 'y'}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("a": Int64, "b": Utf8) + +query ? +select * from t order by column1.a; +---- +{a: 1, b: x} +{a: 2, b: y} + +statement ok +drop table t; + +# Test implicit coercion with nested structs (same field order) +statement ok +create table t as + select {outer: {x: 1, y: 2}} as column1 + union all + select {outer: {x: 3, y: 4}}; + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("outer": Struct("x": Int64, "y": Int64)) + +query ? +select column1 from t order by column1.outer.x; +---- +{outer: {x: 1, y: 2}} +{outer: {x: 3, y: 4}} + +statement ok +drop table t; + +# Test implicit coercion with type widening (Int32 -> Int64) +statement ok +create table t as values({id: 1, val: 100}), ({id: 2, val: 9223372036854775807}); + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("id": Int64, "val": Int64) + +query ? 
+select * from t order by column1.id; +---- +{id: 1, val: 100} +{id: 2, val: 9223372036854775807} + +statement ok +drop table t; + +# Test implicit coercion with nested struct and type coercion +statement ok +create table t as + select {name: 'Alice', data: {score: 100, active: true}} as column1 + union all + select {name: 'Bob', data: {score: 200, active: false}}; + +query T +select arrow_typeof(column1) from t limit 1; +---- +Struct("name": Utf8, "data": Struct("score": Int64, "active": Boolean)) + +query ? +select column1 from t order by column1.name; +---- +{name: Alice, data: {score: 100, active: true}} +{name: Bob, data: {score: 200, active: false}} + +statement ok +drop table t; + +############################ +# Field Reordering Tests (using explicit CAST) +############################ + +# Test explicit cast with field reordering in VALUES - basic case +query ? +select CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT)); +---- +{r: b, c: 2.3} + +# Test explicit cast with field reordering - multiple rows +query ? +select * from (values + (CAST({c: 1, r: 'a'} AS STRUCT(r VARCHAR, c FLOAT))), + (CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT))) +) order by column1.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +# Test table with explicit cast for field reordering +statement ok +create table t as select CAST({c: 1, r: 'a'} AS STRUCT(r VARCHAR, c FLOAT)) as s +union all +select CAST({c: 2.3, r: 'b'} AS STRUCT(r VARCHAR, c FLOAT)); + +query T +select arrow_typeof(s) from t limit 1; +---- +Struct("r": Utf8View, "c": Float32) + +query ? +select * from t order by s.r; +---- +{r: a, c: 1.0} +{r: b, c: 2.3} + +statement ok +drop table t; + +# Test field reordering with nullable fields using CAST +query ? +select CAST({b: NULL, a: 42} AS STRUCT(a INT, b INT)); +---- +{a: 42, b: NULL} + +# Test field reordering with nested structs using CAST +query ? 
+select CAST({outer: {y: 4, x: 3}} AS STRUCT(outer STRUCT(x INT, y INT))); +---- +{outer: {x: 3, y: 4}} + +# Test complex nested field reordering +query ? +select CAST( + {data: {active: false, score: 200}, name: 'Bob'} + AS STRUCT(name VARCHAR, data STRUCT(score INT, active BOOLEAN)) +); +---- +{name: Bob, data: {score: 200, active: false}} + +############################ +# Array Literal Tests with Struct Field Reordering (Implicit Coercion) +############################ + +# Test array literal with reordered struct fields - implicit coercion by name +# Field order in unified schema is determined during type coercion +query ? +select [{r: 'a', c: 1}, {c: 2.3, r: 'b'}]; +---- +[{c: 1.0, r: a}, {c: 2.3, r: b}] + +# Test array literal with same-named fields but different order +# Fields are reordered during coercion +query ? +select [{a: 1, b: 2}, {b: 3, a: 4}]; +---- +[{b: 2, a: 1}, {b: 3, a: 4}] + +# Test array literal with explicit cast to unify struct schemas with partial overlap +# Use CAST to explicitly unify schemas when fields don't match completely +query ? +select [ + CAST({a: 1, b: 2} AS STRUCT(a INT, b INT, c INT)), + CAST({b: 3, c: 4} AS STRUCT(a INT, b INT, c INT)) +]; +---- +[{a: 1, b: 2, c: NULL}, {a: NULL, b: 3, c: 4}] + +# Test NULL handling in array literals with reordered but matching fields +query ? +select [{a: NULL, b: 1}, {b: 2, a: NULL}]; +---- +[{b: 1, a: NULL}, {b: 2, a: NULL}] + +# Verify arrow_typeof for array with reordered struct fields +# The unified schema type follows the coercion order +query T +select arrow_typeof([{x: 1, y: 2}, {y: 3, x: 4}]); +---- +List(Struct("y": Int64, "x": Int64)) + +# Test array of structs with matching nested fields in different order +# Inner nested fields are also reordered during coercion +query ? 
+select [ + {id: 1, info: {name: 'Alice', age: 30}}, + {info: {age: 25, name: 'Bob'}, id: 2} +]; +---- +[{info: {age: 30, name: Alice}, id: 1}, {info: {age: 25, name: Bob}, id: 2}] + +# Test nested arrays with matching struct fields (different order) +query ? +select [[{x: 1, y: 2}], [{y: 4, x: 3}]]; +---- +[[{x: 1, y: 2}], [{x: 3, y: 4}]] + +# Test array literal with float type coercion across elements +query ? +select [{val: 1}, {val: 2.5}]; +---- +[{val: 1.0}, {val: 2.5}] + +############################ +# Dynamic Array Construction Tests (from Table Columns) +############################ + +# Setup test table with struct columns for dynamic array construction +statement ok +create table t_complete_overlap ( + s1 struct(x int, y int), + s2 struct(y int, x int) +) as values + ({x: 1, y: 2}, {y: 3, x: 4}), + ({x: 5, y: 6}, {y: 7, x: 8}); + +# Test 1: Complete overlap - same fields, different order +# Verify arrow_typeof for dynamically created array +query T +select arrow_typeof([s1, s2]) from t_complete_overlap limit 1; +---- +List(Struct("y": Int32, "x": Int32)) + +# Verify values are correctly mapped by name in the array +# Field order follows the second column's field order +query ? 
+select [s1, s2] from t_complete_overlap order by s1.x; +---- +[{y: 2, x: 1}, {y: 3, x: 4}] +[{y: 6, x: 5}, {y: 7, x: 8}] + +statement ok +drop table t_complete_overlap; + +# Test 2: Partial overlap - some shared fields between columns +# Note: Columns must have the exact same field set for array construction to work +# Test with identical field set (all fields present in both columns) +statement ok +create table t_partial_overlap ( + col_a struct(name VARCHAR, age int, active boolean), + col_b struct(age int, name VARCHAR, active boolean) +) as values + ({name: 'Alice', age: 30, active: true}, {age: 25, name: 'Bob', active: false}), + ({name: 'Charlie', age: 35, active: true}, {age: 40, name: 'Diana', active: false}); + +# Verify unified type includes all fields from both structs +query T +select arrow_typeof([col_a, col_b]) from t_partial_overlap limit 1; +---- +List(Struct("age": Int32, "name": Utf8View, "active": Boolean)) + +# Verify values are correctly mapped by name in the array +# Field order follows the second column's field order +query ? +select [col_a, col_b] from t_partial_overlap order by col_a.name; +---- +[{age: 30, name: Alice, active: true}, {age: 25, name: Bob, active: false}] +[{age: 35, name: Charlie, active: true}, {age: 40, name: Diana, active: false}] + +statement ok +drop table t_partial_overlap; + +# Test 3: Complete field set matching (no CAST needed) +# Schemas already align; confirm unified type and values +statement ok +create table t_with_cast ( + col_x struct(id int, description VARCHAR), + col_y struct(id int, description VARCHAR) +) as values + ({id: 1, description: 'First'}, {id: 10, description: 'First Value'}), + ({id: 2, description: 'Second'}, {id: 20, description: 'Second Value'}); + +# Verify type unification with all fields +query T +select arrow_typeof([col_x, col_y]) from t_with_cast limit 1; +---- +List(Struct("id": Int32, "description": Utf8View)) + +# Verify values remain aligned by name +query ? 
+select [col_x, col_y] from t_with_cast order by col_x.id; +---- +[{id: 1, description: First}, {id: 10, description: First Value}] +[{id: 2, description: Second}, {id: 20, description: Second Value}] + +statement ok +drop table t_with_cast; + +# Test 4: Explicit CAST for partial field overlap scenarios +# When columns have different field sets, use explicit CAST to unify schemas +query ? +select [ + CAST({id: 1} AS STRUCT(id INT, description VARCHAR)), + CAST({id: 10, description: 'Value'} AS STRUCT(id INT, description VARCHAR)) +]; +---- +[{id: 1, description: NULL}, {id: 10, description: Value}] + +# Test 5: Complex nested structs with field reordering +# Nested fields must have the exact same field set for array construction +statement ok +create table t_nested ( + col_1 struct(id int, outer struct(x int, y int)), + col_2 struct(id int, outer struct(x int, y int)) +) as values + ({id: 100, outer: {x: 1, y: 2}}, {id: 101, outer: {x: 4, y: 3}}), + ({id: 200, outer: {x: 5, y: 6}}, {id: 201, outer: {x: 8, y: 7}}); + +# Verify nested struct in unified schema +query T +select arrow_typeof([col_1, col_2]) from t_nested limit 1; +---- +List(Struct("id": Int32, "outer": Struct("x": Int32, "y": Int32))) + +# Verify nested field values are correctly mapped +query ? +select [col_1, col_2] from t_nested order by col_1.id; +---- +[{id: 100, outer: {x: 1, y: 2}}, {id: 101, outer: {x: 4, y: 3}}] +[{id: 200, outer: {x: 5, y: 6}}, {id: 201, outer: {x: 8, y: 7}}] + +statement ok +drop table t_nested; + +# Test 6: NULL handling with matching field sets +statement ok +create table t_nulls ( + col_a struct(val int, flag boolean), + col_b struct(val int, flag boolean) +) as values + ({val: 1, flag: true}, {val: 10, flag: false}), + ({val: NULL, flag: false}, {val: NULL, flag: true}); + +# Verify NULL values are preserved +query ? 
+select [col_a, col_b] from t_nulls order by col_a.val; +---- +[{val: 1, flag: true}, {val: 10, flag: false}] +[{val: NULL, flag: false}, {val: NULL, flag: true}] + +statement ok +drop table t_nulls; + +# Test 7: Multiple columns with complete field matching +statement ok +create table t_multi ( + col1 struct(a int, b int, c int), + col2 struct(a int, b int, c int) +) as values + ({a: 1, b: 2, c: 3}, {a: 10, b: 20, c: 30}), + ({a: 4, b: 5, c: 6}, {a: 40, b: 50, c: 60}); + +# Verify array with complete field matching +query T +select arrow_typeof([col1, col2]) from t_multi limit 1; +---- +List(Struct("a": Int32, "b": Int32, "c": Int32)) + +# Verify values are correctly unified +query ? +select [col1, col2] from t_multi order by col1.a; +---- +[{a: 1, b: 2, c: 3}, {a: 10, b: 20, c: 30}] +[{a: 4, b: 5, c: 6}, {a: 40, b: 50, c: 60}] + +statement ok +drop table t_multi; + +############################ +# Comprehensive Implicit Struct Coercion Suite +############################ + +# Test 1: VALUES clause with field reordering coerced by name into declared schema +statement ok +create table implicit_values_reorder ( + s struct(a int, b int) +) as values + ({a: 1, b: 2}), + ({b: 3, a: 4}); + +query T +select arrow_typeof(s) from implicit_values_reorder limit 1; +---- +Struct("a": Int32, "b": Int32) + +query ? 
+select s from implicit_values_reorder order by s.a; +---- +{a: 1, b: 2} +{a: 4, b: 3} + +statement ok +drop table implicit_values_reorder; + +# Test 2: Array literal coercion with reordered struct fields +query IIII +select + [{a: 1, b: 2}, {b: 3, a: 4}][1]['a'], + [{a: 1, b: 2}, {b: 3, a: 4}][1]['b'], + [{a: 1, b: 2}, {b: 3, a: 4}][2]['a'], + [{a: 1, b: 2}, {b: 3, a: 4}][2]['b']; +---- +1 2 4 3 + +# Test 3: Array construction from columns with reordered struct fields +statement ok +create table struct_columns_order ( + s1 struct(a int, b int), + s2 struct(b int, a int) +) as values + ({a: 1, b: 2}, {b: 3, a: 4}), + ({a: 5, b: 6}, {b: 7, a: 8}); + +query IIII +select + [s1, s2][1]['a'], + [s1, s2][1]['b'], + [s1, s2][2]['a'], + [s1, s2][2]['b'] +from struct_columns_order +order by s1['a']; +---- +1 2 4 3 +5 6 8 7 + +statement ok +drop table struct_columns_order; + +# Test 4: UNION with struct field reordering +query II +select s['a'], s['b'] +from ( + select {a: 1, b: 2} as s + union all + select {b: 3, a: 4} as s +) t +order by s['a']; +---- +1 2 +4 3 + +# Test 5: CTE with struct coercion across branches +query II +with + t1 as (select {a: 1, b: 2} as s), + t2 as (select {b: 3, a: 4} as s) +select t1.s['a'], t1.s['b'] from t1 +union all +select t2.s['a'], t2.s['b'] from t2 +order by 1; +---- +1 2 +4 3 + +# Test 6: Struct aggregation retains name-based mapping +statement ok +create table agg_structs_reorder ( + k int, + s struct(x int, y int) +) as values + (1, {x: 1, y: 2}), + (1, {y: 3, x: 4}), + (2, {x: 5, y: 6}); + +query I? 
+select k, array_agg(s) from agg_structs_reorder group by k order by k; +---- +1 [{x: 1, y: 2}, {x: 4, y: 3}] +2 [{x: 5, y: 6}] + +statement ok +drop table agg_structs_reorder; + +# Test 7: Nested struct coercion with reordered inner fields +query IIII +with nested as ( + select [{outer: {inner: 1, value: 2}}, {outer: {value: 3, inner: 4}}] as arr +) +select + arr[1]['outer']['inner'], + arr[1]['outer']['value'], + arr[2]['outer']['inner'], + arr[2]['outer']['value'] +from nested; +---- +1 2 4 3 + +# Test 8: Partial name overlap - currently errors (field count mismatch detected) +# This is a documented limitation: structs must have exactly same field set for coercion +query error DataFusion error: Error during planning: Inconsistent data type across values list +select column1 from (values ({a: 1, b: 2}), ({b: 3, c: 4})) order by column1['a']; + +# Negative test: mismatched struct field counts are rejected (documented limitation) +query error DataFusion error: .* +select [{a: 1}, {a: 2, b: 3}]; + +# Test 9: INSERT with name-based struct coercion into target schema +statement ok +create table target_struct_insert (s struct(a int, b int)); + +statement ok +insert into target_struct_insert values ({b: 1, a: 2}); + +query ? 
+select s from target_struct_insert; +---- +{a: 2, b: 1} + +statement ok +drop table target_struct_insert; + +# Test 10: CASE expression with different struct field orders +query II +select + (case when true then {a: 1, b: 2} else {b: 3, a: 4} end)['a'] as a_val, + (case when true then {a: 1, b: 2} else {b: 3, a: 4} end)['b'] as b_val; +---- +1 2 + +############################ +# JOIN Coercion Tests +############################ + +# Test: Struct coercion in JOIN ON condition +statement ok +create table t_left ( + id int, + s struct(x int, y int) +) as values + (1, {x: 1, y: 2}), + (2, {x: 3, y: 4}); + +statement ok +create table t_right ( + id int, + s struct(y int, x int) +) as values + (1, {y: 2, x: 1}), + (2, {y: 4, x: 3}); + +# JOIN on reordered struct fields - matched by name +query IIII +select t_left.id, t_left.s['x'], t_left.s['y'], t_right.id +from t_left +join t_right on t_left.s = t_right.s +order by t_left.id; +---- +1 1 2 1 +2 3 4 2 + +statement ok +drop table t_left; + +statement ok +drop table t_right; + +# Test: Struct coercion with filtered JOIN +statement ok +create table orders ( + order_id int, + customer struct(name varchar, id int) +) as values + (1, {name: 'Alice', id: 100}), + (2, {name: 'Bob', id: 101}), + (3, {name: 'Charlie', id: 102}); + +statement ok +create table customers ( + customer_id int, + info struct(id int, name varchar) +) as values + (100, {id: 100, name: 'Alice'}), + (101, {id: 101, name: 'Bob'}), + (103, {id: 103, name: 'Diana'}); + +# Join with struct field reordering - names matched, not positions +query I +select count(*) from orders +join customers on orders.customer = customers.info +where orders.order_id <= 2; +---- +2 + +statement ok +drop table orders; + +statement ok +drop table customers; + +############################ +# WHERE Predicate Coercion Tests +############################ + +# Test: Struct equality in WHERE clause with field reordering +statement ok +create table t_where ( + id int, + s struct(x int, 
y int) +) as values + (1, {x: 1, y: 2}), + (2, {x: 3, y: 4}), + (3, {x: 1, y: 2}); + +# WHERE clause with struct comparison - coerced by name +query I +select id from t_where +where s = {y: 2, x: 1} +order by id; +---- +1 +3 + +statement ok +drop table t_where; + +# Test: Struct IN clause with reordering +statement ok +create table t_in ( + id int, + s struct(a int, b varchar) +) as values + (1, {a: 1, b: 'x'}), + (2, {a: 2, b: 'y'}), + (3, {a: 1, b: 'x'}); + +# IN clause with reordered struct literals +query I +select id from t_in +where s in ({b: 'x', a: 1}, {b: 'y', a: 2}) +order by id; +---- +1 +2 +3 + +statement ok +drop table t_in; + +# Test: Struct BETWEEN (not supported, but documents limitation) +# Structs don't support BETWEEN, but can use comparison operators + +statement ok +create table t_between ( + id int, + s struct(val int) +) as values + (1, {val: 10}), + (2, {val: 20}), + (3, {val: 30}); + +# Comparison via field extraction works +query I +select id from t_between +where s['val'] >= 20 +order by id; +---- +2 +3 + +statement ok +drop table t_between; + +############################ +# Window Function Coercion Tests +############################ + +# Test: Struct in window function PARTITION BY +statement ok +create table t_window ( + id int, + s struct(category int, value int) +) as values + (1, {category: 1, value: 10}), + (2, {category: 1, value: 20}), + (3, {category: 2, value: 30}), + (4, {category: 2, value: 40}); + +# Window partition on struct field via extraction +query III +select + id, + s['value'], + row_number() over (partition by s['category'] order by s['value']) +from t_window +order by id; +---- +1 10 1 +2 20 2 +3 30 1 +4 40 2 + +statement ok +drop table t_window; + +# Test: Struct in window function ORDER BY with coercion +statement ok +create table t_rank ( + id int, + s struct(rank_val int, group_id int) +) as values + (1, {rank_val: 10, group_id: 1}), + (2, {rank_val: 20, group_id: 1}), + (3, {rank_val: 15, group_id: 2}); + +# 
Window ranking with struct field extraction +query III +select + id, + s['rank_val'], + rank() over (partition by s['group_id'] order by s['rank_val']) +from t_rank +order by id; +---- +1 10 1 +2 20 2 +3 15 1 + +statement ok +drop table t_rank; + +# Test: Aggregate function with struct coercion across window partitions +statement ok +create table t_agg_window ( + id int, + partition_id int, + s struct(amount int) +) as values + (1, 1, {amount: 100}), + (2, 1, {amount: 200}), + (3, 2, {amount: 150}); + +# Running sum via extracted struct field +query III +select + id, + partition_id, + sum(s['amount']) over (partition by partition_id order by id) +from t_agg_window +order by id; +---- +1 1 100 +2 1 300 +3 2 150 + +statement ok +drop table t_agg_window; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index da0bfc89d5848..9c7c2ddb5d85c 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -430,7 +430,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. 
-statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist/SetComparison subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -1469,3 +1469,198 @@ logical_plan statement count 0 drop table person; + +# Set comparison subqueries (ANY/ALL) +statement ok +create table set_cmp_t(v int) as values (1), (6), (10); + +statement ok +create table set_cmp_s(v int) as values (5), (null); + +statement ok +create table set_cmp_empty(v int); + +query I rowsort +select v from set_cmp_t where v > any(select v from set_cmp_s); +---- +10 +6 + +query I rowsort +select v from set_cmp_t where v < all(select v from set_cmp_empty); +---- +1 +10 +6 + +statement count 0 +drop table set_cmp_t; + +statement count 0 +drop table set_cmp_s; + +statement count 0 +drop table set_cmp_empty; + +query TT +explain select v from (values (1), (6), (10)) set_cmp_t(v) where v > any(select v from (values (5), (null)) set_cmp_s(v)); +---- +logical_plan +01)Projection: set_cmp_t.v +02)--Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND NOT __correlated_sq_3.mark AND Boolean(NULL) +03)----LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_3.v IS TRUE +04)------Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND Boolean(NULL) +05)--------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_2.v IS NULL +06)----------Filter: __correlated_sq_1.mark OR Boolean(NULL) 
+07)------------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_1.v IS TRUE +08)--------------SubqueryAlias: set_cmp_t +09)----------------Projection: column1 AS v +10)------------------Values: (Int64(1)), (Int64(6)), (Int64(10)) +11)--------------SubqueryAlias: __correlated_sq_1 +12)----------------SubqueryAlias: set_cmp_s +13)------------------Projection: column1 AS v +14)--------------------Values: (Int64(5)), (Int64(NULL)) +15)----------SubqueryAlias: __correlated_sq_2 +16)------------SubqueryAlias: set_cmp_s +17)--------------Projection: column1 AS v +18)----------------Values: (Int64(5)), (Int64(NULL)) +19)------SubqueryAlias: __correlated_sq_3 +20)--------SubqueryAlias: set_cmp_s +21)----------Projection: column1 AS v +22)------------Values: (Int64(5)), (Int64(NULL)) + +# correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] 
+11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_totalprice = __correlated_sq_1.price, orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------Projection: lineitem.l_extendedprice AS price, lineitem.l_extendedprice, lineitem.l_orderkey +13)--------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# Setup tables for recursive correlation tests +statement ok +CREATE TABLE employees ( + employee_id INTEGER, + employee_name VARCHAR, + dept_id INTEGER, + salary DECIMAL +); + +statement ok +CREATE TABLE project_assignments ( + project_id INTEGER, + employee_id INTEGER, + priority INTEGER +); + +# Provided recursive scalar subquery explain 
case +query TT +EXPLAIN SELECT e1.employee_name, e1.salary +FROM employees e1 +WHERE e1.salary > ( + SELECT AVG(e2.salary) + FROM employees e2 + WHERE e2.dept_id = e1.dept_id + AND e2.salary > ( + SELECT AVG(e3.salary) + FROM employees e3 + WHERE e3.dept_id = e1.dept_id + ) +); +---- +logical_plan +01)Projection: e1.employee_name, e1.salary +02)--Inner Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary) +03)----SubqueryAlias: e1 +04)------TableScan: employees projection=[employee_name, dept_id, salary] +05)----SubqueryAlias: __scalar_sq_1 +06)------Projection: avg(e2.salary), e2.dept_id +07)--------Aggregate: groupBy=[[e2.dept_id]], aggr=[[avg(e2.salary)]] +08)----------Projection: e2.dept_id, e2.salary +09)------------Inner Join: Filter: CAST(e2.salary AS Decimal128(38, 14)) > __scalar_sq_2.avg(e3.salary) AND __scalar_sq_2.dept_id = e1.dept_id +10)--------------SubqueryAlias: e2 +11)----------------TableScan: employees projection=[dept_id, salary] +12)--------------SubqueryAlias: __scalar_sq_2 +13)----------------Projection: avg(e3.salary), e3.dept_id +14)------------------Aggregate: groupBy=[[e3.dept_id]], aggr=[[avg(e3.salary)]] +15)--------------------SubqueryAlias: e3 +16)----------------------TableScan: employees projection=[dept_id, salary] + +# Check shadowing: `dept_id` should resolve to the nearest outer relation (`e2`) +# in the innermost subquery rather than the outermost +query TT +EXPLAIN SELECT e1.employee_id +FROM employees e1 +WHERE EXISTS ( + SELECT 1 + FROM employees e2 + WHERE EXISTS ( + SELECT 1 + FROM project_assignments p + WHERE p.project_id = dept_id + ) +); +---- +logical_plan +01)LeftSemi Join: +02)--SubqueryAlias: e1 +03)----TableScan: employees projection=[employee_id] +04)--SubqueryAlias: __correlated_sq_2 +05)----Projection: +06)------LeftSemi Join: e2.dept_id = __correlated_sq_1.project_id +07)--------SubqueryAlias: e2 +08)----------TableScan: employees 
projection=[dept_id] +09)--------SubqueryAlias: __correlated_sq_1 +10)----------SubqueryAlias: p +11)------------TableScan: project_assignments projection=[project_id] + +statement count 0 +drop table employees; + +statement count 0 +drop table project_assignments; diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index cf8a091880d3d..f0e00ffc69233 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -160,17 +160,20 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s # Test generate_series with invalid arguments # -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM generate_series(5, 1) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query I SELECT * FROM generate_series(-6, 6, -1) +---- query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM generate_series(-6, 6, 0) -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM generate_series(6, -6, 1) +---- statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments @@ -298,17 +301,20 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: start=1, en # Test range with invalid arguments # -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM range(5, 1) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query I SELECT * FROM 
range(-6, 6, -1) +---- query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM range(-6, 6, 0) -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query I SELECT * FROM range(6, -6, 1) +---- statement error DataFusion error: Error during planning: range function requires 1 to 3 arguments @@ -378,11 +384,13 @@ SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00 2023-01-03T00:00:00 2023-01-02T00:00:00 -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query P SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query P SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-02T00:00:00', INTERVAL '-1' DAY) +---- query error DataFusion error: Error during planning: range function with timestamps requires exactly 3 arguments SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00') @@ -489,11 +497,13 @@ query P SELECT * FROM range(DATE '1992-09-01', DATE '1992-10-01', NULL::INTERVAL) ---- -query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +query P SELECT * FROM range(DATE '2023-01-03', DATE '2023-01-01', INTERVAL '1' DAY) +---- -query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +query P SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-02', INTERVAL '-1' DAY) +---- query error DataFusion error: Error during planning: range function with dates requires exactly 3 arguments SELECT * FROM range(DATE '2023-01-01', DATE 
'2023-01-03') diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index aba468d21fd08..8a1fef0722297 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -383,7 +383,7 @@ physical_plan 03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age] 04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet, predicate=DynamicFilter [ empty ] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index 0ee60a1e8afb2..b01110b567ca8 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -71,17 +71,18 @@ physical_plan 04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] 05)--------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 06)----------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] -07)------------AggregateExec: mode=SinglePartitioned, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as 
alias1], aggr=[] -08)--------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] -09)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] -11)--------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -12)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false -13)--------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -14)----------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND p_size@3 IN (SET) ([49, 14, 23, 45, 19, 3, 36, 9]) -15)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false -17)----------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -18)------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] -19)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -20)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, 
projection=[s_suppkey, s_comment], file_type=csv, has_header=false +07)------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] +08)--------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +09)----------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] +10)------------------HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] +11)--------------------CoalescePartitionsExec +12)----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] +13)------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +14)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false +15)------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +16)--------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND p_size@3 IN (SET) ([49, 14, 23, 45, 19, 3, 36, 9]) +17)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +18)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, 
has_header=false +19)--------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] +20)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +21)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/truncate.slt b/datafusion/sqllogictest/test_files/truncate.slt new file mode 100644 index 0000000000000..ad3ccbb1a7cf4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/truncate.slt @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Truncate Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +statement ok +insert into t1 values (1, 'abc', 3.14, 4), (2, 'def', 2.71, 5); + +# Truncate all rows from table +query TT +explain truncate table t1; +---- +logical_plan +01)Dml: op=[Truncate] table=[t1] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 't1' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table + +# Test TRUNCATE with fully qualified table name +statement ok +create schema test_schema; + +statement ok +create table test_schema.t5(a int); + +query TT +explain truncate table test_schema.t5; +---- +logical_plan +01)Dml: op=[Truncate] table=[test_schema.t5] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 'test_schema.t5' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table + +# Test TRUNCATE with CASCADE option +statement error TRUNCATE with CASCADE/RESTRICT is not supported +TRUNCATE TABLE t1 CASCADE; + +# Test TRUNCATE with multiple tables +statement error TRUNCATE with multiple tables is not supported +TRUNCATE TABLE t1, t2; + +statement error TRUNCATE with PARTITION is not supported +TRUNCATE TABLE t1 PARTITION (p1); + +statement error TRUNCATE with ONLY is not supported +TRUNCATE ONLY t1; + +statement error TRUNCATE with RESTART/CONTINUE IDENTITY is not supported +TRUNCATE TABLE t1 RESTART IDENTITY; + +# Test TRUNCATE without TABLE keyword +query TT +explain truncate t1; +---- +logical_plan +01)Dml: op=[Truncate] table=[t1] +02)--EmptyRelation: rows=0 +physical_plan_error +01)TRUNCATE operation on table 't1' +02)caused by +03)This feature is not implemented: TRUNCATE not supported for Base table diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index e3baa8fedcf63..8ab5b63e697d6 100644 --- 
a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -254,3 +254,30 @@ DROP TABLE orders; ######################################## ## Test type coercion with UNIONs end ## ######################################## + +# https://github.com/apache/datafusion/issues/15661 +# LIKE is a string pattern matching operator and is not supported for nested types. + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 STRING, v2 BOOLEAN); + +statement ok +INSERT INTO t0(v0, v2) VALUES (123, true); + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE ((REGEXP_MATCH(t0.v1, t0.v1)) NOT LIKE (REGEXP_MATCH(t0.v1, t0.v1, 'jH'))); + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) NOT LIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) LIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) ILIKE []; + +query error There isn't a common type to coerce .* in .* expression +SELECT true FROM t0 WHERE (REGEXP_MATCH(t0.v1, t0.v1)) NOT ILIKE []; + +statement ok +DROP TABLE t0; diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index b79b6d2fe5e9e..d858d0ae3ea4e 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -494,22 +494,25 @@ physical_plan 01)CoalescePartitionsExec: fetch=3 02)--UnionExec 03)----ProjectionExec: expr=[count(Int64(1))@0 as cnt] -04)------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -07)------------ProjectionExec: expr=[] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] 
-09)----------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -11)--------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] -12)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true -14)----ProjectionExec: expr=[1 as cnt] -15)------PlaceholderRowExec -16)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -17)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] -18)--------ProjectionExec: expr=[1 as c1] -19)----------PlaceholderRowExec +04)------GlobalLimitExec: skip=0, fetch=3 +05)--------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +06)----------CoalescePartitionsExec +07)------------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +08)--------------ProjectionExec: expr=[] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +10)------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +12)----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true 
+15)----ProjectionExec: expr=[1 as cnt] +16)------GlobalLimitExec: skip=0, fetch=3 +17)--------PlaceholderRowExec +18)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +19)------GlobalLimitExec: skip=0, fetch=3 +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING": nullable Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] +21)----------ProjectionExec: expr=[1 as c1] +22)------------PlaceholderRowExec ######## diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index f939cd0154a82..73aeb6c99d0db 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -666,15 +666,15 @@ explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unn logical_plan 01)Projection: __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 02)--Unnest: lists[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] -03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 04)------Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] 05)--------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), 
recursive_unnest_table.column3 06)----------TableScan: recursive_unnest_table projection=[column3] physical_plan 01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] 02)--UnnestExec -03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] -04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] 05)--------UnnestExec 06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 07)------------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 8bfec86497ef0..753afc08d4f60 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,12 +41,13 @@ datafusion = { workspace = true, features = ["sql"] } half = { workspace = true } itertools = { workspace = true } object_store = { workspace = true } -pbjson-types = { workspace = true } +# We need to match the version in substrait, so we don't use the workspace version here +pbjson-types = { version = "0.8.0" } prost = { workspace = true } substrait = { version = "0.62", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } -uuid = { version = "1.19.0", features = ["v4"] } +uuid = { workspace = true, features = ["v4"] } [dev-dependencies] datafusion = { workspace = true, features = ["nested_expressions", "unicode_expressions"] } diff --git 
a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 407408aaa71b3..0819fd3a592f9 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -23,7 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -#![deny(clippy::allow_attributes)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Serialize / Deserialize DataFusion Plans to [Substrait.io] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs index 112f1ea374b32..ad38b6addee0b 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -102,6 +102,7 @@ pub(crate) fn from_substrait_literal( }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), + #[expect(deprecated)] Some(LiteralType::Timestamp(t)) => { // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead #[expect(deprecated)] @@ -385,6 +386,7 @@ pub(crate) fn from_substrait_literal( use interval_day_to_second::PrecisionMode; // DF only supports millisecond precision, so for any more granular type we lose precision let milliseconds = match precision_mode { + #[expect(deprecated)] Some(PrecisionMode::Microseconds(ms)) => ms / 1000, None => { if *subseconds != 0 { diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs index 6c2bc652bb19c..71e3b9e96e153 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -88,6 +88,7 @@ pub async fn from_substrait_rex( consumer.consume_subquery(expr.as_ref(), input_schema).await } RexType::Nested(expr) => 
consumer.consume_nested(expr, input_schema).await, + #[expect(deprecated)] RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, RexType::DynamicParameter(expr) => { consumer.consume_dynamic_parameter(expr, input_schema).await diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index 15fe7947a1e1f..61a381e9eb407 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -16,12 +16,13 @@ // under the License. use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{DFSchema, Spans, substrait_err}; -use datafusion::logical_expr::expr::{Exists, InSubquery}; -use datafusion::logical_expr::{Expr, Subquery}; +use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err}; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Expr, Operator, Subquery}; use std::sync::Arc; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::SubqueryType; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; pub async fn from_subquery( @@ -96,8 +97,53 @@ pub async fn from_subquery( ), } } - other_type => { - substrait_err!("Subquery type {other_type:?} not implemented") + SubqueryType::SetComparison(comparison) => { + let left = comparison.left.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a left expression") + })?; + let right = comparison.right.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a right relation") + })?; + let reduction_op = match ReductionOp::try_from(comparison.reduction_op) { + Ok(ReductionOp::Any) => SetQuantifier::Any, + Ok(ReductionOp::All) => 
SetQuantifier::All, + _ => { + return substrait_err!( + "Unsupported reduction op for SetComparison: {}", + comparison.reduction_op + ); + } + }; + let comparison_op = match ComparisonOp::try_from(comparison.comparison_op) + { + Ok(ComparisonOp::Eq) => Operator::Eq, + Ok(ComparisonOp::Ne) => Operator::NotEq, + Ok(ComparisonOp::Lt) => Operator::Lt, + Ok(ComparisonOp::Gt) => Operator::Gt, + Ok(ComparisonOp::Le) => Operator::LtEq, + Ok(ComparisonOp::Ge) => Operator::GtEq, + _ => { + return substrait_err!( + "Unsupported comparison op for SetComparison: {}", + comparison.comparison_op + ); + } + }; + + let left_expr = consumer.consume_expression(left, input_schema).await?; + let plan = consumer.consume_rel(right).await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + Ok(Expr::SetComparison(SetComparison::new( + Box::new(left_expr), + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + comparison_op, + reduction_op, + ))) } }, None => { diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs index bd6d94736e265..12a8a77199b1a 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs @@ -30,6 +30,7 @@ pub async fn from_fetch_rel( let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let offset = match &fetch.offset_mode { + #[expect(deprecated)] Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { Some(consumer.consume_expression(expr, &empty_schema).await?) 
@@ -37,6 +38,7 @@ pub async fn from_fetch_rel( None => None, }; let count = match &fetch.count_mode { + #[expect(deprecated)] Some(fetch_rel::CountMode::Count(count)) => { // -1 means that ALL records should be returned, equivalent to None (*count != -1).then(|| lit(*count)) diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index eb2cc967ca236..9ef7a0dd46b86 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -88,6 +88,7 @@ pub fn from_substrait_type( }, r#type::Kind::Fp32(_) => Ok(DataType::Float32), r#type::Kind::Fp64(_) => Ok(DataType::Float64), + #[expect(deprecated)] r#type::Kind::Timestamp(ts) => { // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead #[expect(deprecated)] diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 5057564d370cf..74b1a65215376 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -141,6 +141,7 @@ pub fn to_substrait_rex( Expr::InList(expr) => producer.handle_in_list(expr, schema), Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema), Expr::ScalarSubquery(expr) => { not_impl_err!("Cannot convert {expr:?} to Substrait") } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs index f2e6ff551223c..e5b9241c10104 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -16,9 +16,11 @@ // under the License. 
use crate::logical_plan::producer::SubstraitProducer; -use datafusion::common::DFSchemaRef; -use datafusion::logical_expr::expr::InSubquery; +use datafusion::common::{DFSchemaRef, substrait_err}; +use datafusion::logical_expr::Operator; +use datafusion::logical_expr::expr::{InSubquery, SetComparison, SetQuantifier}; use substrait::proto::expression::subquery::InPredicate; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::{RexType, ScalarFunction}; use substrait::proto::function_argument::ArgType; use substrait::proto::{Expression, FunctionArgument}; @@ -70,3 +72,53 @@ pub fn from_in_subquery( Ok(substrait_subquery) } } + +fn comparison_op_to_proto(op: &Operator) -> datafusion::common::Result { + match op { + Operator::Eq => Ok(ComparisonOp::Eq), + Operator::NotEq => Ok(ComparisonOp::Ne), + Operator::Lt => Ok(ComparisonOp::Lt), + Operator::Gt => Ok(ComparisonOp::Gt), + Operator::LtEq => Ok(ComparisonOp::Le), + Operator::GtEq => Ok(ComparisonOp::Ge), + _ => substrait_err!("Unsupported operator {op:?} for SetComparison subquery"), + } +} + +fn reduction_op_to_proto( + quantifier: &SetQuantifier, +) -> datafusion::common::Result { + match quantifier { + SetQuantifier::Any => Ok(ReductionOp::Any), + SetQuantifier::All => Ok(ReductionOp::All), + } +} + +pub fn from_set_comparison( + producer: &mut impl SubstraitProducer, + set_comparison: &SetComparison, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let comparison_op = comparison_op_to_proto(&set_comparison.op)? as i32; + let reduction_op = reduction_op_to_proto(&set_comparison.quantifier)? 
as i32; + let left = producer.handle_expr(set_comparison.expr.as_ref(), schema)?; + let subquery_plan = + producer.handle_plan(set_comparison.subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetComparison( + Box::new(substrait::proto::expression::subquery::SetComparison { + reduction_op, + comparison_op, + left: Some(Box::new(left)), + right: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index ffc920ffe609e..c7518bd04e4a1 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -20,14 +20,17 @@ use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_projection, from_repartition, from_scalar_function, from_set_comparison, + from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, + from_union, from_values, from_window, from_window_function, to_substrait_rel, + to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; use datafusion::execution::registry::SerializerRegistry; -use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use 
datafusion::logical_expr::expr::{ + Alias, InList, InSubquery, SetComparison, WindowFunction, +}; use datafusion::logical_expr::{ Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias, @@ -361,6 +364,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_set_comparison( + &mut self, + set_comparison: &SetComparison, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_set_comparison(self, set_comparison, schema) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 194098cf060e3..2d814654ba68c 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -78,17 +78,17 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]] Projection: PARTSUPP.PS_SUPPLYCOST Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION TableScan: REGION - Cross Join: - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: PART TableScan: SUPPLIER TableScan: PARTSUPP @@ -112,8 +112,8 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - 
LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_MKTSEGMENT = Utf8("BUILDING") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: LINEITEM TableScan: CUSTOMER TableScan: ORDERS @@ -153,11 +153,11 @@ mod tests { Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("ASIA") AND ORDERS.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: - Cross Join: - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -221,9 +221,9 @@ mod tests { Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RETURNFLAG = 
Utf8("R") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -247,16 +247,16 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: SUPPLIER TableScan: NATION @@ -276,7 +276,7 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND 
LINEITEM.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: + Cross Join: TableScan: ORDERS TableScan: LINEITEM "# @@ -314,7 +314,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32("1995-09-01") AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) - Cross Join: + Cross Join: TableScan: LINEITEM TableScan: PART "# @@ -345,7 +345,7 @@ mod tests { Projection: SUPPLIER.S_SUPPKEY Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) TableScan: SUPPLIER - Cross Join: + Cross Join: TableScan: PARTSUPP TableScan: PART "# @@ -379,8 +379,8 @@ mod tests { Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]] Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY TableScan: LINEITEM - Cross Join: - Cross Join: + Cross Join: + Cross Join: TableScan: CUSTOMER TableScan: ORDERS TableScan: LINEITEM @@ -397,7 +397,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]] Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#12") AND (PART.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR PART.P_CONTAINER = 
CAST(Utf8("SM PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND (PART.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#34") AND (PART.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") - Cross Join: + Cross Join: TableScan: LINEITEM TableScan: PART "# @@ -428,7 +428,7 @@ mod tests { Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) 
TableScan: LINEITEM TableScan: PARTSUPP - Cross Join: + Cross Join: TableScan: SUPPLIER TableScan: NATION "# @@ -454,9 +454,9 @@ mod tests { Subquery: Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE TableScan: LINEITEM - Cross Join: - Cross Join: - Cross Join: + Cross Join: + Cross Join: + Cross Join: TableScan: SUPPLIER TableScan: LINEITEM TableScan: ORDERS diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 5ebacaf5336d4..5a72f9e64636b 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -20,6 +20,9 @@ #[cfg(test)] mod tests { use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + use datafusion::common::test_util::format_batches; + use std::collections::HashSet; + use datafusion::common::Result; use datafusion::dataframe::DataFrame; use datafusion::prelude::SessionContext; @@ -229,4 +232,49 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn duplicate_name_in_union() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/duplicate_name_in_union.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_snapshot!( + plan, + @r" + Projection: foo AS col1, bar AS col2 + Union + Projection: foo, bar + Values: (Int64(100), Int64(200)) + Projection: x, foo + Values: (Int32(300), Int64(400)) + " + ); + + // Trigger execution to ensure plan validity + let results = DataFrame::new(ctx.state(), plan).collect().await?; + + assert_snapshot!( + format_batches(&results)?, + @r" + +------+------+ + | col1 | col2 | + +------+------+ + | 100 | 200 | + | 300 | 400 | + +------+------+ + ", + ); + + // also verify that the output schema has unique field names + let schema = results[0].schema(); + for batch in 
&results { + assert_eq!(schema, batch.schema()); + } + let field_names: HashSet<_> = schema.fields().iter().map(|f| f.name()).collect(); + assert_eq!(field_names.len(), schema.fields().len()); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 98b35bf082ec4..386ef9dc55b08 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -29,14 +29,15 @@ use std::mem::size_of_val; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::tree_node::Transformed; -use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err}; +use datafusion::common::{DFSchema, DFSchemaRef, Spans, not_impl_err, plan_err}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::expr::{SetComparison, SetQuantifier}; use datafusion::logical_expr::{ - EmptyRelation, Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, - Repartition, UserDefinedLogicalNode, Values, Volatility, + EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator, + Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -689,6 +690,29 @@ async fn roundtrip_exists_filter() -> Result<()> { Ok(()) } +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_any_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_set_comparison_plan(&ctx, SetQuantifier::Any, Operator::Gt).await?; + let proto = to_substrait_plan(&plan, 
&ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::Gt, SetQuantifier::Any); + Ok(()) +} + +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_all_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = + build_set_comparison_plan(&ctx, SetQuantifier::All, Operator::NotEq).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::NotEq, SetQuantifier::All); + Ok(()) +} + #[tokio::test] async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> { let plan = generate_plan_from_sql( @@ -1353,7 +1377,7 @@ async fn roundtrip_literal_named_struct() -> Result<()> { assert_snapshot!( plan, @r#" - Projection: Struct({int_field:1,boolean_field:true,string_field:}) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL) + Projection: CAST(Struct({c0:1,c1:true,c2:}) AS Struct("int_field": Int64, "boolean_field": Boolean, "string_field": Utf8View)) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL) TableScan: data projection=[] "# ); @@ -1373,10 +1397,10 @@ async fn roundtrip_literal_renamed_struct() -> Result<()> { assert_snapshot!( plan, - @r" - Projection: Struct({int_field:1}) AS Struct({c0:1}) + @r#" + Projection: CAST(Struct({c0:1}) AS Struct("int_field": Int32)) TableScan: data projection=[] - " + "# ); Ok(()) } @@ -1865,6 +1889,56 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { Ok(()) } +async fn build_set_comparison_plan( + ctx: &SessionContext, + quantifier: SetQuantifier, + op: Operator, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let 
subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? + .build()?; + let predicate = Expr::SetComparison(SetComparison::new( + Box::new(col("data.a")), + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + op, + quantifier, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? + .build() +} + +fn assert_set_comparison_predicate( + plan: &LogicalPlan, + expected_op: Operator, + expected_quantifier: SetQuantifier, +) { + let predicate = match plan { + LogicalPlan::Projection(p) => match p.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + match predicate { + Expr::SetComparison(set_comparison) => { + assert_eq!(set_comparison.op, expected_op); + assert_eq!(set_comparison.quantifier, expected_quantifier); + } + other => panic!("expected SetComparison predicate, got {other:?}"), + } +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index d0f9511760938..2d7257fad3394 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -17,7 +17,6 @@ #[cfg(test)] mod tests { - use datafusion::common::assert_contains; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -44,8 +43,18 @@ mod tests { serializer::deserialize(path).await?; // Test case 2: serializing to an existing file should fail. 
- let got = serializer::serialize(sql, &ctx, path).await.unwrap_err(); - assert_contains!(got.to_string(), "File exists"); + let got = serializer::serialize(sql, &ctx, path) + .await + .unwrap_err() + .to_string(); + assert!( + [ + "File exists", // unix + "os error 80" // windows + ] + .iter() + .any(|s| got.contains(s)) + ); fs::remove_file(path)?; diff --git a/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json b/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json new file mode 100644 index 0000000000000..1da2ff6131368 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/duplicate_name_in_union.substrait.json @@ -0,0 +1,171 @@ +{ + "version": { + "minorNumber": 54, + "producer": "datafusion-test" + }, + "relations": [ + { + "root": { + "input": { + "set": { + "common": { + "direct": {} + }, + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["foo", "bar"], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "expressions": [ + { + "fields": [ + { + "literal": { + "i64": "100" + } + }, + { + "literal": { + "i64": "200" + } + } + ] + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["x", "foo"], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { 
+ "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "expressions": [ + { + "fields": [ + { + "literal": { + "i32": 300 + } + }, + { + "literal": { + "i64": "400" + } + } + ] + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["col1", "col2"] + } + } + ] +} diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 2d63980aadf0d..6a6824579b4e8 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -484,6 +484,7 @@ pub mod test { } RexType::DynamicParameter(_) => {} // Enum is deprecated + #[expect(deprecated)] RexType::Enum(_) => {} } Ok(()) diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 16fa9790f65b6..0bb304af6f9c3 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -47,7 +47,7 @@ chrono = { version = "0.4", features = ["wasmbind"] } # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for # code size when deploying. 
console_error_panic_hook = { version = "0.1.1", optional = true } -datafusion = { workspace = true, features = ["parquet", "sql"] } +datafusion = { workspace = true, features = ["compression", "parquet", "sql"] } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } @@ -59,11 +59,13 @@ getrandom = { version = "0.3", features = ["wasm_js"] } wasm-bindgen = "0.2.99" [dev-dependencies] +bytes = { workspace = true } +futures = { workspace = true } object_store = { workspace = true } # needs to be compiled tokio = { workspace = true } url = { workspace = true } -wasm-bindgen-test = "0.3.56" +wasm-bindgen-test = "0.3.58" [package.metadata.cargo-machete] ignored = ["chrono", "getrandom"] diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 98ee1a34f01eb..8f175b0001229 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -13,7 +13,7 @@ }, "devDependencies": { "copy-webpack-plugin": "12.0.2", - "webpack": "5.94.0", + "webpack": "5.105.0", "webpack-cli": "5.1.4", "webpack-dev-server": "5.2.1" } @@ -32,17 +32,13 @@ } }, "node_modules/@jridgewell/gen-mapping": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", - "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", "dev": true, "dependencies": { - "@jridgewell/set-array": "^1.2.1", - "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/sourcemap-codec": "^1.5.0", "@jridgewell/trace-mapping": "^0.3.24" - }, - "engines": { - "node": ">=6.0.0" } 
}, "node_modules/@jridgewell/resolve-uri": { @@ -54,19 +50,10 @@ "node": ">=6.0.0" } }, - "node_modules/@jridgewell/set-array": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", - "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, "node_modules/@jridgewell/source-map": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", - "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "version": "0.3.11", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.11.tgz", + "integrity": "sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==", "dev": true, "dependencies": { "@jridgewell/gen-mapping": "^0.3.5", @@ -74,15 +61,15 @@ } }, "node_modules/@jridgewell/sourcemap-codec": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", - "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", "dev": true }, "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.25", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", - "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": 
"sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "dev": true, "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -187,10 +174,30 @@ "@types/node": "*" } }, + "node_modules/@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "dev": true, + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + "dev": true, + "dependencies": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, "node_modules/@types/estree": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", - "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true }, "node_modules/@types/express": { @@ -234,9 +241,9 @@ } }, "node_modules/@types/json-schema": { - "version": "7.0.13", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", - "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, "node_modules/@types/mime": { @@ -333,148 +340,148 @@ 
} }, "node_modules/@webassemblyjs/ast": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", - "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", "dev": true, "dependencies": { - "@webassemblyjs/helper-numbers": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" } }, "node_modules/@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", - "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", "dev": true }, "node_modules/@webassemblyjs/helper-api-error": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", - "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": "sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", "dev": true }, "node_modules/@webassemblyjs/helper-buffer": { - "version": "1.12.1", - "resolved": 
"https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", - "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", "dev": true }, "node_modules/@webassemblyjs/helper-numbers": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", - "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": "sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", "dev": true, "dependencies": { - "@webassemblyjs/floating-point-hex-parser": "1.11.6", - "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", "@xtuc/long": "4.2.2" } }, "node_modules/@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", - "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", "dev": true }, "node_modules/@webassemblyjs/helper-wasm-section": { - "version": "1.12.1", - "resolved": 
"https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", - "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" } }, "node_modules/@webassemblyjs/ieee754": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", - "integrity": "sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", "dev": true, "dependencies": { "@xtuc/ieee754": "^1.2.0" } }, "node_modules/@webassemblyjs/leb128": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", - "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", "dev": true, "dependencies": { "@xtuc/long": "4.2.2" } }, "node_modules/@webassemblyjs/utf8": { - "version": "1.11.6", - 
"resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", - "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", "dev": true }, "node_modules/@webassemblyjs/wasm-edit": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", - "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-opt": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1", - "@webassemblyjs/wast-printer": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" } }, "node_modules/@webassemblyjs/wasm-gen": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", - "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", + "version": "1.14.1", + "resolved": 
"https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "node_modules/@webassemblyjs/wasm-opt": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", - "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" } }, "node_modules/@webassemblyjs/wasm-parser": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", - "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + "integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", "dev": true, 
"dependencies": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-api-error": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "node_modules/@webassemblyjs/wast-printer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", - "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/ast": "1.14.1", "@xtuc/long": "4.2.2" } }, @@ -548,9 +555,9 @@ } }, "node_modules/acorn": { - "version": "8.12.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", - "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "bin": { "acorn": "bin/acorn" @@ -559,25 +566,28 @@ "node": ">=0.4.0" } }, - "node_modules/acorn-import-attributes": { - "version": "1.9.5", - "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", - "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", + 
"node_modules/acorn-import-phases": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/acorn-import-phases/-/acorn-import-phases-1.0.4.tgz", + "integrity": "sha512-wKmbr/DDiIXzEOiWrTTUcDm24kQ2vGfZQvM2fwg2vXqR5uW6aapr7ObPtj1th32b9u90/Pf4AItvdTh42fBmVQ==", "dev": true, + "engines": { + "node": ">=10.13.0" + }, "peerDependencies": { - "acorn": "^8" + "acorn": "^8.14.0" } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" }, "funding": { "type": "github", @@ -601,35 +611,16 @@ } } }, - "node_modules/ajv-formats/node_modules/ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3" }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - 
"node_modules/ajv-formats/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "node_modules/ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", - "dev": true, "peerDependencies": { - "ajv": "^6.9.1" + "ajv": "^8.8.2" } }, "node_modules/ansi-html-community": { @@ -665,6 +656,15 @@ "dev": true, "license": "MIT" }, + "node_modules/baseline-browser-mapping": { + "version": "2.9.19", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.19.tgz", + "integrity": "sha512-ipDqC8FrAl/76p2SSWKSI+H9tFwm7vYqXQrItCuiVPt26Km0jS+NzSsBWAaBusvSbQcfJG+JitdMm+wZAgTYqg==", + "dev": true, + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/batch": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", @@ -753,9 +753,9 @@ } }, "node_modules/browserslist": { - "version": "4.21.11", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", - "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "version": "4.28.1", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.28.1.tgz", + "integrity": "sha512-ZC5Bd0LgJXgwGqUknZY/vkUQ04r8NXnJZ3yYi4vDmSiZmC/pdSN0NbNRPxZpbtO4uAfDUAFffO8IZoM3Gj8IkA==", "dev": true, "funding": [ { @@ -772,10 +772,11 @@ } ], "dependencies": { - "caniuse-lite": "^1.0.30001538", - "electron-to-chromium": "^1.4.526", - "node-releases": "^2.0.13", - "update-browserslist-db": "^1.0.13" + "baseline-browser-mapping": "^2.9.0", + "caniuse-lite": 
"^1.0.30001759", + "electron-to-chromium": "^1.5.263", + "node-releases": "^2.0.27", + "update-browserslist-db": "^1.2.0" }, "bin": { "browserslist": "cli.js" @@ -847,9 +848,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001538", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", - "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "version": "1.0.30001768", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001768.tgz", + "integrity": "sha512-qY3aDRZC5nWPgHUgIB84WL+nySuo19wk0VJpp/XI9T34lrvkyhRvNVOFJOp2kxClQhiFBu+TaUSudf6oa3vkSA==", "dev": true, "funding": [ { @@ -1092,36 +1093,6 @@ "webpack": "^5.1.0" } }, - "node_modules/copy-webpack-plugin/node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/copy-webpack-plugin/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, "node_modules/copy-webpack-plugin/node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -1135,33 +1106,6 @@ "node": ">=10.13.0" } }, - "node_modules/copy-webpack-plugin/node_modules/json-schema-traverse": 
{ - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true, - "license": "MIT" - }, - "node_modules/copy-webpack-plugin/node_modules/schema-utils": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", - "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/core-util-is": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", @@ -1307,9 +1251,9 @@ "license": "MIT" }, "node_modules/electron-to-chromium": { - "version": "1.4.528", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", - "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "version": "1.5.286", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.286.tgz", + "integrity": "sha512-9tfDXhJ4RKFNerfjdCcZfufu49vg620741MNs26a9+bhLThdB+plgMeou98CAaHu/WATj2iHOOHTp1hWtABj2A==", "dev": true }, "node_modules/encodeurl": { @@ -1323,13 +1267,13 @@ } }, "node_modules/enhanced-resolve": { - "version": "5.17.1", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", - "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", + "version": "5.19.0", + "resolved": 
"https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.19.0.tgz", + "integrity": "sha512-phv3E1Xl4tQOShqSte26C7Fl84EwUdZsyOuSSk9qtAGyyQs2s3jJzComh+Abf4g187lUUAvH+H26omrqia2aGg==", "dev": true, "dependencies": { "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" + "tapable": "^2.3.0" }, "engines": { "node": ">=10.13.0" @@ -1368,9 +1312,9 @@ } }, "node_modules/es-module-lexer": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", - "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", "dev": true }, "node_modules/es-object-atoms": { @@ -1387,9 +1331,9 @@ } }, "node_modules/escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true, "engines": { "node": ">=6" @@ -1604,16 +1548,10 @@ "node": ">=8.6.0" } }, - "node_modules/fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true - }, "node_modules/fast-uri": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", - "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + 
"version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", "dev": true, "funding": [ { @@ -1624,8 +1562,7 @@ "type": "opencollective", "url": "https://opencollective.com/fastify" } - ], - "license": "BSD-3-Clause" + ] }, "node_modules/fastest-levenshtein": { "version": "1.0.16", @@ -2304,9 +2241,9 @@ "dev": true }, "node_modules/json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", "dev": true }, "node_modules/kind-of": { @@ -2330,12 +2267,16 @@ } }, "node_modules/loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.1.tgz", + "integrity": "sha512-IWqP2SCPhyVFTBtRcgMHdzlf9ul25NwaFx4wCEH/KjAXuuHY4yNjvPXsBokp8jCB936PyWRaPKUNh8NvylLp2Q==", "dev": true, "engines": { "node": ">=6.11.5" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" } }, "node_modules/locate-path": { @@ -2619,9 +2560,9 @@ } }, "node_modules/node-releases": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", - "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "version": "2.0.27", + 
"resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", + "integrity": "sha512-nmh3lCkYZ3grZvqcCH+fjmQ7X+H0OeZgP40OierEaAptX4XofMh5kwNbWh7lBduUzCcV/8kZ+NDLCwm2iorIlA==", "dev": true }, "node_modules/normalize-path": { @@ -2801,9 +2742,9 @@ } }, "node_modules/picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "dev": true }, "node_modules/picomatch": { @@ -2860,15 +2801,6 @@ "node": ">= 0.10" } }, - "node_modules/punycode": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", - "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", - "dev": true, - "engines": { - "node": ">=6" - } - }, "node_modules/qs": { "version": "6.13.0", "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", @@ -3106,14 +3038,15 @@ "license": "MIT" }, "node_modules/schema-utils": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", - "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "dependencies": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" }, "engines": { "node": ">= 10.13.0" @@ -3558,22 
+3491,26 @@ } }, "node_modules/tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "engines": { "node": ">=6" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" } }, "node_modules/terser": { - "version": "5.31.6", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", - "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", + "version": "5.46.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz", + "integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==", "dev": true, "dependencies": { "@jridgewell/source-map": "^0.3.3", - "acorn": "^8.8.2", + "acorn": "^8.15.0", "commander": "^2.20.0", "source-map-support": "~0.5.20" }, @@ -3585,16 +3522,16 @@ } }, "node_modules/terser-webpack-plugin": { - "version": "5.3.10", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", - "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", + "version": "5.3.16", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.16.tgz", + "integrity": "sha512-h9oBFCWrq78NyWWVcSwZarJkZ01c2AyGrzs1crmHZO3QUg9D61Wu4NPjBy69n7JqylFF5y+CsUZYmYEIZ3mR+Q==", "dev": true, "dependencies": { - "@jridgewell/trace-mapping": "^0.3.20", + "@jridgewell/trace-mapping": "^0.3.25", "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.1", - "terser": "^5.26.0" + "schema-utils": 
"^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" }, "engines": { "node": ">= 10.13.0" @@ -3691,9 +3628,9 @@ } }, "node_modules/update-browserslist-db": { - "version": "1.0.13", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", - "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz", + "integrity": "sha512-Js0m9cx+qOgDxo0eMiFGEueWztz+d4+M3rGlmKPT+T4IS/jP4ylw3Nwpu6cpTTP8R1MAC1kF4VbdLt3ARf209w==", "dev": true, "funding": [ { @@ -3710,8 +3647,8 @@ } ], "dependencies": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" + "escalade": "^3.2.0", + "picocolors": "^1.1.1" }, "bin": { "update-browserslist-db": "cli.js" @@ -3720,15 +3657,6 @@ "browserslist": ">= 4.21.0" } }, - "node_modules/uri-js": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", - "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", - "dev": true, - "dependencies": { - "punycode": "^2.1.0" - } - }, "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", @@ -3764,9 +3692,9 @@ } }, "node_modules/watchpack": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", - "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.5.1.tgz", + "integrity": "sha512-Zn5uXdcFNIA1+1Ei5McRd+iRzfhENPCe7LeABkJtNulSxjma+l7ltNx55BWZkRlwRnpOgHqxnjyaDgJnNXnqzg==", "dev": true, "dependencies": { "glob-to-regexp": "^0.4.1", @@ -3786,34 +3714,36 @@ } }, "node_modules/webpack": { - "version": "5.94.0", - "resolved": 
"https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", - "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", - "dev": true, - "dependencies": { - "@types/estree": "^1.0.5", - "@webassemblyjs/ast": "^1.12.1", - "@webassemblyjs/wasm-edit": "^1.12.1", - "@webassemblyjs/wasm-parser": "^1.12.1", - "acorn": "^8.7.1", - "acorn-import-attributes": "^1.9.5", - "browserslist": "^4.21.10", + "version": "5.105.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.0.tgz", + "integrity": "sha512-gX/dMkRQc7QOMzgTe6KsYFM7DxeIONQSui1s0n/0xht36HvrgbxtM1xBlgx596NbpHuQU8P7QpKwrZYwUX48nw==", + "dev": true, + "dependencies": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.8", + "@types/json-schema": "^7.0.15", + "@webassemblyjs/ast": "^1.14.1", + "@webassemblyjs/wasm-edit": "^1.14.1", + "@webassemblyjs/wasm-parser": "^1.14.1", + "acorn": "^8.15.0", + "acorn-import-phases": "^1.0.3", + "browserslist": "^4.28.1", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.17.1", - "es-module-lexer": "^1.2.1", + "enhanced-resolve": "^5.19.0", + "es-module-lexer": "^2.0.0", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", + "loader-runner": "^4.3.1", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^3.2.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.3.10", - "watchpack": "^2.4.1", - "webpack-sources": "^3.2.3" + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", + "terser-webpack-plugin": "^5.3.16", + "watchpack": "^2.5.1", + "webpack-sources": "^3.3.3" }, "bin": { "webpack": "bin/webpack.js" @@ -3915,63 +3845,6 @@ } } }, - "node_modules/webpack-dev-middleware/node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": 
"sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/webpack-dev-middleware/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, - "node_modules/webpack-dev-middleware/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true, - "license": "MIT" - }, - "node_modules/webpack-dev-middleware/node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 10.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/webpack-dev-server": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", @@ -4030,59 +3903,6 @@ } } }, - "node_modules/webpack-dev-server/node_modules/ajv": { - "version": 
"8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/webpack-dev-server/node_modules/ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "dependencies": { - "fast-deep-equal": "^3.1.3" - }, - "peerDependencies": { - "ajv": "^8.8.2" - } - }, - "node_modules/webpack-dev-server/node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "node_modules/webpack-dev-server/node_modules/schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", - "dev": true, - "dependencies": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - }, - "engines": { - "node": ">= 12.13.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/webpack" - } - }, "node_modules/webpack-merge": { "version": "5.9.0", "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.9.0.tgz", @@ -4096,10 +3916,10 @@ "node": ">=10.0.0" } }, - "node_modules/webpack/node_modules/webpack-sources": { - "version": 
"3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "node_modules/webpack-sources": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.3.3.tgz", + "integrity": "sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==", "dev": true, "engines": { "node": ">=10.13.0" @@ -4180,13 +4000,12 @@ "dev": true }, "@jridgewell/gen-mapping": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", - "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", "dev": true, "requires": { - "@jridgewell/set-array": "^1.2.1", - "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/sourcemap-codec": "^1.5.0", "@jridgewell/trace-mapping": "^0.3.24" } }, @@ -4196,16 +4015,10 @@ "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", "dev": true }, - "@jridgewell/set-array": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", - "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", - "dev": true - }, "@jridgewell/source-map": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", - "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", + "version": "0.3.11", + "resolved": 
"https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.11.tgz", + "integrity": "sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==", "dev": true, "requires": { "@jridgewell/gen-mapping": "^0.3.5", @@ -4213,15 +4026,15 @@ } }, "@jridgewell/sourcemap-codec": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", - "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", "dev": true }, "@jridgewell/trace-mapping": { - "version": "0.3.25", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", - "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "dev": true, "requires": { "@jridgewell/resolve-uri": "^3.1.0", @@ -4304,10 +4117,30 @@ "@types/node": "*" } }, + "@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "dev": true, + "requires": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + 
"dev": true, + "requires": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, "@types/estree": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", - "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", "dev": true }, "@types/express": { @@ -4350,9 +4183,9 @@ } }, "@types/json-schema": { - "version": "7.0.13", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", - "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, "@types/mime": { @@ -4443,148 +4276,148 @@ } }, "@webassemblyjs/ast": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", - "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", "dev": true, "requires": { - "@webassemblyjs/helper-numbers": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" } }, "@webassemblyjs/floating-point-hex-parser": { - "version": "1.11.6", - "resolved": 
"https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", - "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", "dev": true }, "@webassemblyjs/helper-api-error": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", - "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": "sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", "dev": true }, "@webassemblyjs/helper-buffer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", - "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", "dev": true }, "@webassemblyjs/helper-numbers": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", - "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": 
"sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", "dev": true, "requires": { - "@webassemblyjs/floating-point-hex-parser": "1.11.6", - "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", "@xtuc/long": "4.2.2" } }, "@webassemblyjs/helper-wasm-bytecode": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", - "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", "dev": true }, "@webassemblyjs/helper-wasm-section": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", - "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" } }, "@webassemblyjs/ieee754": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", - "integrity": 
"sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", "dev": true, "requires": { "@xtuc/ieee754": "^1.2.0" } }, "@webassemblyjs/leb128": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", - "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", "dev": true, "requires": { "@xtuc/long": "4.2.2" } }, "@webassemblyjs/utf8": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", - "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", "dev": true }, "@webassemblyjs/wasm-edit": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", - "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - 
"@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-opt": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1", - "@webassemblyjs/wast-printer": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" } }, "@webassemblyjs/wasm-gen": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", - "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "@webassemblyjs/wasm-opt": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", - "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", "dev": 
true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-buffer": "1.12.1", - "@webassemblyjs/wasm-gen": "1.12.1", - "@webassemblyjs/wasm-parser": "1.12.1" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" } }, "@webassemblyjs/wasm-parser": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", - "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + "integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", - "@webassemblyjs/helper-api-error": "1.11.6", - "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/ieee754": "1.11.6", - "@webassemblyjs/leb128": "1.11.6", - "@webassemblyjs/utf8": "1.11.6" + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" } }, "@webassemblyjs/wast-printer": { - "version": "1.12.1", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", - "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/ast": "1.14.1", "@xtuc/long": "4.2.2" } }, @@ -4632,28 +4465,28 
@@ } }, "acorn": { - "version": "8.12.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", - "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true }, - "acorn-import-attributes": { - "version": "1.9.5", - "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", - "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", + "acorn-import-phases": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/acorn-import-phases/-/acorn-import-phases-1.0.4.tgz", + "integrity": "sha512-wKmbr/DDiIXzEOiWrTTUcDm24kQ2vGfZQvM2fwg2vXqR5uW6aapr7ObPtj1th32b9u90/Pf4AItvdTh42fBmVQ==", "dev": true, "requires": {} }, "ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "requires": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" } }, "ajv-formats": { @@ -4663,34 +4496,16 @@ "dev": true, "requires": { "ajv": "^8.0.0" - }, - "dependencies": { - "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": 
"sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - } } }, "ajv-keywords": { - "version": "3.5.2", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", - "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, - "requires": {} + "requires": { + "fast-deep-equal": "^3.1.3" + } }, "ansi-html-community": { "version": "0.0.8", @@ -4714,6 +4529,12 @@ "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", "dev": true }, + "baseline-browser-mapping": { + "version": "2.9.19", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.19.tgz", + "integrity": "sha512-ipDqC8FrAl/76p2SSWKSI+H9tFwm7vYqXQrItCuiVPt26Km0jS+NzSsBWAaBusvSbQcfJG+JitdMm+wZAgTYqg==", + "dev": true + }, "batch": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", @@ -4783,15 +4604,16 @@ } }, "browserslist": { - "version": "4.21.11", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", - "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "version": "4.28.1", + "resolved": 
"https://registry.npmjs.org/browserslist/-/browserslist-4.28.1.tgz", + "integrity": "sha512-ZC5Bd0LgJXgwGqUknZY/vkUQ04r8NXnJZ3yYi4vDmSiZmC/pdSN0NbNRPxZpbtO4uAfDUAFffO8IZoM3Gj8IkA==", "dev": true, "requires": { - "caniuse-lite": "^1.0.30001538", - "electron-to-chromium": "^1.4.526", - "node-releases": "^2.0.13", - "update-browserslist-db": "^1.0.13" + "baseline-browser-mapping": "^2.9.0", + "caniuse-lite": "^1.0.30001759", + "electron-to-chromium": "^1.5.263", + "node-releases": "^2.0.27", + "update-browserslist-db": "^1.2.0" } }, "buffer-from": { @@ -4836,9 +4658,9 @@ } }, "caniuse-lite": { - "version": "1.0.30001538", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", - "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "version": "1.0.30001768", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001768.tgz", + "integrity": "sha512-qY3aDRZC5nWPgHUgIB84WL+nySuo19wk0VJpp/XI9T34lrvkyhRvNVOFJOp2kxClQhiFBu+TaUSudf6oa3vkSA==", "dev": true }, "chokidar": { @@ -4991,27 +4813,6 @@ "serialize-javascript": "^6.0.2" }, "dependencies": { - "ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -5020,24 +4821,6 @@ 
"requires": { "is-glob": "^4.0.3" } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.0.tgz", - "integrity": "sha512-Gf9qqc58SpCA/xdziiHz35F4GNIWYWZrEshUc/G/r5BnLph6xpKuLeoJoQuj5WfBIx/eQLf+hmVPYHaxJu7V2g==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } } } }, @@ -5145,9 +4928,9 @@ "dev": true }, "electron-to-chromium": { - "version": "1.4.528", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", - "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "version": "1.5.286", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.286.tgz", + "integrity": "sha512-9tfDXhJ4RKFNerfjdCcZfufu49vg620741MNs26a9+bhLThdB+plgMeou98CAaHu/WATj2iHOOHTp1hWtABj2A==", "dev": true }, "encodeurl": { @@ -5157,13 +4940,13 @@ "dev": true }, "enhanced-resolve": { - "version": "5.17.1", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", - "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", + "version": "5.19.0", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.19.0.tgz", + "integrity": "sha512-phv3E1Xl4tQOShqSte26C7Fl84EwUdZsyOuSSk9qtAGyyQs2s3jJzComh+Abf4g187lUUAvH+H26omrqia2aGg==", "dev": true, "requires": { "graceful-fs": "^4.2.4", - "tapable": "^2.2.0" + "tapable": "^2.3.0" } }, "envinfo": { @@ -5185,9 +4968,9 @@ "dev": true }, "es-module-lexer": { - "version": "1.3.1", - 
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", - "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", "dev": true }, "es-object-atoms": { @@ -5200,9 +4983,9 @@ } }, "escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true }, "escape-html": { @@ -5358,16 +5141,10 @@ "micromatch": "^4.0.8" } }, - "fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true - }, "fast-uri": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", - "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", "dev": true }, "fastest-levenshtein": { @@ -5831,9 +5608,9 @@ "dev": true }, "json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": 
"sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", "dev": true }, "kind-of": { @@ -5853,9 +5630,9 @@ } }, "loader-runner": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", - "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.1.tgz", + "integrity": "sha512-IWqP2SCPhyVFTBtRcgMHdzlf9ul25NwaFx4wCEH/KjAXuuHY4yNjvPXsBokp8jCB936PyWRaPKUNh8NvylLp2Q==", "dev": true }, "locate-path": { @@ -6035,9 +5812,9 @@ "dev": true }, "node-releases": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", - "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "version": "2.0.27", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", + "integrity": "sha512-nmh3lCkYZ3grZvqcCH+fjmQ7X+H0OeZgP40OierEaAptX4XofMh5kwNbWh7lBduUzCcV/8kZ+NDLCwm2iorIlA==", "dev": true }, "normalize-path": { @@ -6159,9 +5936,9 @@ "dev": true }, "picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "dev": true }, "picomatch": { @@ -6203,12 +5980,6 @@ } } }, - "punycode": { - "version": "2.3.0", - 
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", - "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", - "dev": true - }, "qs": { "version": "6.13.0", "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", @@ -6362,14 +6133,15 @@ "dev": true }, "schema-utils": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", - "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "requires": { - "@types/json-schema": "^7.0.8", - "ajv": "^6.12.5", - "ajv-keywords": "^3.5.2" + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" } }, "select-hose": { @@ -6705,34 +6477,34 @@ "dev": true }, "tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true }, "terser": { - "version": "5.31.6", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", - "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", + "version": "5.46.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz", + "integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==", "dev": true, "requires": { "@jridgewell/source-map": "^0.3.3", 
- "acorn": "^8.8.2", + "acorn": "^8.15.0", "commander": "^2.20.0", "source-map-support": "~0.5.20" } }, "terser-webpack-plugin": { - "version": "5.3.10", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", - "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", + "version": "5.3.16", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.16.tgz", + "integrity": "sha512-h9oBFCWrq78NyWWVcSwZarJkZ01c2AyGrzs1crmHZO3QUg9D61Wu4NPjBy69n7JqylFF5y+CsUZYmYEIZ3mR+Q==", "dev": true, "requires": { - "@jridgewell/trace-mapping": "^0.3.20", + "@jridgewell/trace-mapping": "^0.3.25", "jest-worker": "^27.4.5", - "schema-utils": "^3.1.1", - "serialize-javascript": "^6.0.1", - "terser": "^5.26.0" + "schema-utils": "^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" } }, "thunky": { @@ -6785,22 +6557,13 @@ "dev": true }, "update-browserslist-db": { - "version": "1.0.13", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", - "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", - "dev": true, - "requires": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" - } - }, - "uri-js": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", - "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz", + "integrity": "sha512-Js0m9cx+qOgDxo0eMiFGEueWztz+d4+M3rGlmKPT+T4IS/jP4ylw3Nwpu6cpTTP8R1MAC1kF4VbdLt3ARf209w==", "dev": true, "requires": { - "punycode": "^2.1.0" + "escalade": "^3.2.0", + "picocolors": "^1.1.1" } }, "util-deprecate": { @@ -6828,9 +6591,9 @@ "dev": true }, "watchpack": { - "version": "2.4.2", - "resolved": 
"https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", - "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.5.1.tgz", + "integrity": "sha512-Zn5uXdcFNIA1+1Ei5McRd+iRzfhENPCe7LeABkJtNulSxjma+l7ltNx55BWZkRlwRnpOgHqxnjyaDgJnNXnqzg==", "dev": true, "requires": { "glob-to-regexp": "^0.4.1", @@ -6847,42 +6610,36 @@ } }, "webpack": { - "version": "5.94.0", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", - "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", - "dev": true, - "requires": { - "@types/estree": "^1.0.5", - "@webassemblyjs/ast": "^1.12.1", - "@webassemblyjs/wasm-edit": "^1.12.1", - "@webassemblyjs/wasm-parser": "^1.12.1", - "acorn": "^8.7.1", - "acorn-import-attributes": "^1.9.5", - "browserslist": "^4.21.10", + "version": "5.105.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.0.tgz", + "integrity": "sha512-gX/dMkRQc7QOMzgTe6KsYFM7DxeIONQSui1s0n/0xht36HvrgbxtM1xBlgx596NbpHuQU8P7QpKwrZYwUX48nw==", + "dev": true, + "requires": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.8", + "@types/json-schema": "^7.0.15", + "@webassemblyjs/ast": "^1.14.1", + "@webassemblyjs/wasm-edit": "^1.14.1", + "@webassemblyjs/wasm-parser": "^1.14.1", + "acorn": "^8.15.0", + "acorn-import-phases": "^1.0.3", + "browserslist": "^4.28.1", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.17.1", - "es-module-lexer": "^1.2.1", + "enhanced-resolve": "^5.19.0", + "es-module-lexer": "^2.0.0", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", - "loader-runner": "^4.2.0", + "loader-runner": "^4.3.1", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^3.2.0", - "tapable": "^2.1.1", - "terser-webpack-plugin": 
"^5.3.10", - "watchpack": "^2.4.1", - "webpack-sources": "^3.2.3" - }, - "dependencies": { - "webpack-sources": { - "version": "3.2.3", - "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", - "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", - "dev": true - } + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", + "terser-webpack-plugin": "^5.3.16", + "watchpack": "^2.5.1", + "webpack-sources": "^3.3.3" } }, "webpack-cli": { @@ -6926,47 +6683,6 @@ "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" - }, - "dependencies": { - "ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } - } } 
}, "webpack-dev-server": { @@ -7003,47 +6719,6 @@ "spdy": "^4.0.2", "webpack-dev-middleware": "^7.4.2", "ws": "^8.18.0" - }, - "dependencies": { - "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - } - }, - "ajv-keywords": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", - "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", - "dev": true, - "requires": { - "fast-deep-equal": "^3.1.3" - } - }, - "json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true - }, - "schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", - "dev": true, - "requires": { - "@types/json-schema": "^7.0.9", - "ajv": "^8.9.0", - "ajv-formats": "^2.1.1", - "ajv-keywords": "^5.1.0" - } - } } }, "webpack-merge": { @@ -7056,6 +6731,12 @@ "wildcard": "^2.0.0" } }, + "webpack-sources": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.3.3.tgz", + "integrity": "sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==", + "dev": true + }, "websocket-driver": { "version": "0.7.4", "resolved": "https://registry.npmjs.org/websocket-driver/-/websocket-driver-0.7.4.tgz", diff --git 
a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index b46993de77d9b..aecc5b689554e 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -27,7 +27,7 @@ "datafusion-wasmtest": "../pkg" }, "devDependencies": { - "webpack": "5.94.0", + "webpack": "5.105.0", "webpack-cli": "5.1.4", "webpack-dev-server": "5.2.1", "copy-webpack-plugin": "12.0.2" diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index c5948bd7343a6..403509515bf31 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -79,7 +79,8 @@ pub fn basic_parse() { mod test { use std::sync::Arc; - use super::*; + use bytes::Bytes; + use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::{ arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray}, @@ -87,8 +88,9 @@ mod test { }, datasource::MemTable, execution::context::SessionContext, + prelude::CsvReadOptions, }; - use datafusion_common::test_util::batches_to_string; + use datafusion_common::{DataFusionError, test_util::batches_to_string}; use datafusion_execution::{ config::SessionConfig, disk_manager::{DiskManagerBuilder, DiskManagerMode}, @@ -96,17 +98,18 @@ mod test { }; use datafusion_physical_plan::collect; use datafusion_sql::parser::DFParser; - use object_store::{ObjectStore, memory::InMemory, path::Path}; + use futures::{StreamExt, TryStreamExt, stream}; + use object_store::{ObjectStore, PutPayload, memory::InMemory, path::Path}; use url::Url; use wasm_bindgen_test::wasm_bindgen_test; wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + #[cfg(target_arch = "wasm32")] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] fn datafusion_test() { - basic_exprs(); - basic_parse(); + super::basic_exprs(); + super::basic_parse(); } fn get_ctx() -> 
Arc { @@ -259,4 +262,55 @@ mod test { +----+-------+" ); } + + #[wasm_bindgen_test(unsupported = tokio::test)] + async fn test_csv_read_xz_compressed() { + let csv_data = "id,value\n1,a\n2,b\n3,c\n"; + let input = Bytes::from(csv_data.as_bytes().to_vec()); + let input_stream = + stream::iter(vec![Ok::(input)]).boxed(); + + let compressed_stream = FileCompressionType::XZ + .convert_to_compress_stream(input_stream) + .unwrap(); + let compressed_data: Vec = compressed_stream.try_collect().await.unwrap(); + + let store = InMemory::new(); + let path = Path::from("data.csv.xz"); + store + .put(&path, PutPayload::from_iter(compressed_data)) + .await + .unwrap(); + + let url = Url::parse("memory://").unwrap(); + let ctx = SessionContext::new(); + ctx.register_object_store(&url, Arc::new(store)); + + let csv_options = CsvReadOptions::new() + .has_header(true) + .file_compression_type(FileCompressionType::XZ) + .file_extension("csv.xz"); + ctx.register_csv("compressed", "memory:///data.csv.xz", csv_options) + .await + .unwrap(); + + let result = ctx + .sql("SELECT * FROM compressed") + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_eq!( + batches_to_string(&result), + "+----+-------+\n\ + | id | value |\n\ + +----+-------+\n\ + | 1 | a |\n\ + | 2 | b |\n\ + | 3 | c |\n\ + +----+-------+" + ); + } } diff --git a/dev/changelog/52.1.0.md b/dev/changelog/52.1.0.md new file mode 100644 index 0000000000000..97a1435c41a44 --- /dev/null +++ b/dev/changelog/52.1.0.md @@ -0,0 +1,46 @@ + + +# Apache DataFusion 52.1.0 Changelog + +This release consists of 3 commits from 3 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. 
+ +**Documentation updates:** + +- [branch-52] Fix Internal error: Assertion failed: !self.finished: LimitedBatchCoalescer (#19785) [#19836](https://github.com/apache/datafusion/pull/19836) (alamb) + +**Other:** + +- [branch-52] fix: expose `ListFilesEntry` [#19818](https://github.com/apache/datafusion/pull/19818) (lonless9) +- [branch 52] Fix grouping set subset satisfaction [#19855](https://github.com/apache/datafusion/pull/19855) (gabotechs) +- Add BatchAdapter to simplify using PhysicalExprAdapter / Projector [#19877](https://github.com/apache/datafusion/pull/19877) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Andrew Lamb + 1 Gabriel + 1 XL Liang +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/release/release-tarball.sh b/dev/release/release-tarball.sh index bd858d23a767c..a284b6c4351f3 100755 --- a/dev/release/release-tarball.sh +++ b/dev/release/release-tarball.sh @@ -43,6 +43,13 @@ fi version=$1 rc=$2 +read -r -p "Proceed to release tarball for ${version}-rc${rc}? [y/N]: " answer +answer=${answer:-no} +if [ "${answer}" != "y" ]; then + echo "Cancelled tarball release!" + exit 1 +fi + tmp_dir=tmp-apache-datafusion-dist echo "Recreate temporary directory: ${tmp_dir}" diff --git a/dev/rust_lint.sh b/dev/rust_lint.sh index 21d4611846413..43d29bd88166d 100755 --- a/dev/rust_lint.sh +++ b/dev/rust_lint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -23,30 +23,103 @@ # Note: The installed checking tools (e.g., taplo) are not guaranteed to match # the CI versions for simplicity, there might be some minor differences. Check # `.github/workflows` for the CI versions. 
+# +# +# +# For each lint scripts: +# +# By default, they run in check mode: +# ./ci/scripts/rust_fmt.sh +# +# With `--write`, scripts perform best-effort auto fixes: +# ./ci/scripts/rust_fmt.sh --write +# +# The `--write` flag assumes a clean git repository (no uncommitted changes); to force +# auto fixes even if there are unstaged changes, use `--allow-dirty`: +# ./ci/scripts/rust_fmt.sh --write --allow-dirty +# +# New scripts can use `rust_fmt.sh` as a reference. + +set -euo pipefail + +usage() { + cat >&2 < /dev/null; then + echo "Installing $cmd using: $install_cmd" + eval "$install_cmd" + fi +} + +MODE="check" +ALLOW_DIRTY=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --write) + MODE="write" + ;; + --allow-dirty) + ALLOW_DIRTY=1 + ;; + -h|--help) + usage + ;; + *) + usage + ;; + esac + shift +done + +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +ensure_tool "taplo" "cargo install taplo-cli --locked" +ensure_tool "hawkeye" "cargo install hawkeye --locked" +ensure_tool "typos" "cargo install typos-cli --locked" + +run_step() { + local name="$1" + shift + echo "[${SCRIPT_NAME}] Running ${name}" + "$@" +} + +declare -a WRITE_STEPS=( + "ci/scripts/rust_fmt.sh|true" + "ci/scripts/rust_clippy.sh|true" + "ci/scripts/rust_toml_fmt.sh|true" + "ci/scripts/license_header.sh|true" + "ci/scripts/typos_check.sh|true" + "ci/scripts/doc_prettier_check.sh|true" +) + +declare -a READONLY_STEPS=( + "ci/scripts/rust_docs.sh|false" +) -# For `.toml` format checking -set -e -if ! command -v taplo &> /dev/null; then - echo "Installing taplo using cargo" - cargo install taplo-cli -fi - -# For Apache licence header checking -if ! command -v hawkeye &> /dev/null; then - echo "Installing hawkeye using cargo" - cargo install hawkeye --locked -fi - -# For spelling checks -if ! 
command -v typos &> /dev/null; then - echo "Installing typos using cargo" - cargo install typos-cli --locked -fi - -ci/scripts/rust_fmt.sh -ci/scripts/rust_clippy.sh -ci/scripts/rust_toml_fmt.sh -ci/scripts/rust_docs.sh -ci/scripts/license_header.sh -ci/scripts/typos_check.sh -ci/scripts/doc_prettier_check.sh +for entry in "${WRITE_STEPS[@]}" "${READONLY_STEPS[@]}"; do + IFS='|' read -r script_path supports_write <<<"$entry" + script_name="$(basename "$script_path")" + args=() + if [[ "$supports_write" == "true" && "$MODE" == "write" ]]; then + args+=(--write) + [[ $ALLOW_DIRTY -eq 1 ]] && args+=(--allow-dirty) + fi + if [[ ${#args[@]} -gt 0 ]]; then + run_step "$script_name" "$script_path" "${args[@]}" + else + run_step "$script_name" "$script_path" + fi +done diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 90bbc5d3bad06..f39bdda3aee87 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -20,14 +20,16 @@ set -e -SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "${SOURCE_DIR}/../" && pwd +ROOT_DIR="$(git rev-parse --show-toplevel)" +cd "${ROOT_DIR}" + +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" TARGET_FILE="docs/source/user-guide/configs.md" PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" PRINT_RUNTIME_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_runtime_config_docs" - echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" + +# Workspace Dependency Graph + +This page shows the dependency relationships between DataFusion's workspace +crates. 
This only includes internal dependencies, external crates like `Arrow` are not included + +The dependency graph is auto-generated by `docs/scripts/generate_dependency_graph.sh` to ensure it stays up-to-date, and the script now runs automatically as part of `docs/build.sh`. + +## Dependency Graph for Workspace Crates + + + +```{raw} html + + +``` + +### Legend + +- black lines: normal dependency +- blue lines: dev-dependency +- green lines: build-dependency +- dotted lines: optional dependency (could be removed by disabling a cargo feature) + +Transitive dependencies are intentionally ignored to keep the graph readable. + +The dependency graph is generated through `cargo depgraph` by `docs/scripts/generate_dependency_graph.sh`. diff --git a/docs/source/contributor-guide/howtos.md b/docs/source/contributor-guide/howtos.md index 1b38e95bf35d6..18d9391d24bbe 100644 --- a/docs/source/contributor-guide/howtos.md +++ b/docs/source/contributor-guide/howtos.md @@ -187,4 +187,4 @@ valid installation of [protoc] (see [installation instructions] for details). ``` [protoc]: https://github.com/protocolbuffers/protobuf#protocol-compiler-installation -[installation instructions]: https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation +[installation instructions]: https://datafusion.apache.org/contributor-guide/development_environment.html#protoc-installation diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index ea42329f2c00f..900df2f88174f 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -199,3 +199,9 @@ Please understand the reviewing capacity is **very limited** for the project, so ### Better ways to contribute than an “AI dump” It's recommended to write a high-quality issue with a clear problem statement and a minimal, reproducible example. This can make it easier for others to contribute. 
+ +### CI Runners + +We use [Runs-On](https://runs-on.com/) for some actions in the main repository, which run in the ASF AWS account to speed up CI time. In forks, these actions run on the default GitHub runners since forks do not have access to ASF infrastructure. + +We also use standard GitHub runners for some actions in the main repository; these are also runnable in forks. diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md index 81ceabb646bf3..5a6caed224cfe 100644 --- a/docs/source/contributor-guide/testing.md +++ b/docs/source/contributor-guide/testing.md @@ -104,6 +104,7 @@ locally by following the [instructions in the documentation]. [sqlite test suite]: https://www.sqlite.org/sqllogictest/dir?ci=tip [instructions in the documentation]: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest#running-tests-sqlite +[extended.yml]: https://github.com/apache/datafusion/blob/main/.github/workflows/extended.yml ## Rust Integration Tests diff --git a/docs/source/download.md b/docs/source/download.md index 7a62e398c02b5..e358f39940cdb 100644 --- a/docs/source/download.md +++ b/docs/source/download.md @@ -26,7 +26,7 @@ For example: ```toml [dependencies] -datafusion = "41.0.0" +datafusion = "52.0.0" ``` While DataFusion is distributed via [crates.io] as a convenience, the diff --git a/docs/source/index.rst b/docs/source/index.rst index 181d54a66477c..4d57faa0cbf73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -134,7 +134,7 @@ To get started, see :caption: Library User Guide library-user-guide/index - library-user-guide/upgrading + library-user-guide/upgrading/index library-user-guide/extensions library-user-guide/using-the-sql-api library-user-guide/extending-sql @@ -159,6 +159,7 @@ To get started, see contributor-guide/communication contributor-guide/development_environment contributor-guide/architecture + contributor-guide/architecture/dependency-graph contributor-guide/testing 
contributor-guide/api-health contributor-guide/howtos diff --git a/docs/source/library-user-guide/extending-sql.md b/docs/source/library-user-guide/extending-sql.md index 409a0fb89a321..687d884895c8b 100644 --- a/docs/source/library-user-guide/extending-sql.md +++ b/docs/source/library-user-guide/extending-sql.md @@ -27,6 +27,11 @@ need to: - Add custom data types not natively supported - Implement SQL constructs like `TABLESAMPLE`, `PIVOT`/`UNPIVOT`, or `MATCH_RECOGNIZE` +You can read more about this topic in the [Extending SQL in DataFusion: from ->> +to TABLESAMPLE] blog. + +[extending sql in datafusion: from ->> to tablesample]: https://datafusion.apache.org/blog/2026/01/12/extending-sql + ## Architecture Overview When DataFusion processes a SQL query, it goes through these stages: @@ -329,7 +334,7 @@ SELECT * FROM sales [`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html [`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html [`sessionstatebuilder`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionStateBuilder.html -[`relationplannercontext`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/trait.RelationPlannerContext.html +[`relationplannercontext`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.RelationPlannerContext.html [exprplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.ExprPlanner.html [typeplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.TypePlanner.html [relationplanner api documentation]: https://docs.rs/datafusion/latest/datafusion/logical_expr/planner/trait.RelationPlanner.html diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 5d033ae3f9e97..48162d6abcdfb 100644 --- 
a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -583,7 +583,6 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html -[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs ## Named Arguments @@ -684,6 +683,10 @@ No function matches the given name and argument types substr(Utf8). Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation. +For background and other considerations, see the [User defined Window Functions in DataFusion] blog. + +[user defined window functions in datafusion]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions + For example, we will declare a user defined window function that computes a moving average. ```rust diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md index 8ed6593d56203..2254776bf6e3c 100644 --- a/docs/source/library-user-guide/query-optimizer.md +++ b/docs/source/library-user-guide/query-optimizer.md @@ -25,11 +25,21 @@ format. DataFusion has modular design, allowing individual crates to be re-used in other projects. 
This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and -contains an extensive set of [`OptimizerRule`]s and [`PhysicalOptimizerRules`] that may rewrite the plan and/or its expressions so +contains an extensive set of [`OptimizerRule`]s and [`PhysicalOptimizerRule`]s that may rewrite the plan and/or its expressions so they execute more quickly while still computing the same result. +For a deeper background on optimizer architecture and rule types and predicates, see +[Optimizing SQL (and DataFrames) in DataFusion, Part 1], [Part 2], +[Using Ordering for Better Plans in Apache DataFusion], and +[Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries]. + [`optimizerrule`]: https://docs.rs/datafusion/latest/datafusion/optimizer/trait.OptimizerRule.html -[`physicaloptimizerrules`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/trait.PhysicalOptimizerRule.html +[`physicaloptimizerrule`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/trait.PhysicalOptimizerRule.html +[optimizing sql (and dataframes) in datafusion, part 1]: https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-one +[part 2]: https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two +[using ordering for better plans in apache datafusion]: https://datafusion.apache.org/blog/2025/03/11/ordering-analysis +[dynamic filters: passing information between operators during execution for 25x faster queries]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters +[`logicalplan`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.LogicalPlan.html ## Running the Optimizer @@ -75,7 +85,7 @@ Please refer to the example to learn more about the general approach to writing optimizer rules and then move onto studying the existing rules. 
-`OptimizerRule` transforms one ['LogicalPlan'] into another which +`OptimizerRule` transforms one [`LogicalPlan`] into another which computes the same results, but in a potentially more efficient way. If there are no suitable transformations for the input plan, the optimizer can simply return it as is. @@ -504,3 +514,5 @@ fn analyze_filter_example() -> Result<()> { Ok(()) } ``` + +[treenode api]: https://docs.rs/datafusion/latest/datafusion/common/tree_node/trait.TreeNode.html diff --git a/docs/source/library-user-guide/table-constraints.md b/docs/source/library-user-guide/table-constraints.md index dea746463d234..252817822d990 100644 --- a/docs/source/library-user-guide/table-constraints.md +++ b/docs/source/library-user-guide/table-constraints.md @@ -37,6 +37,6 @@ They are provided for informational purposes and can be used by custom - **Foreign keys and check constraints**: These constraints are parsed but are not validated or used during query planning. -[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/enum.TableConstraint.html -[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/functional_dependencies/struct.Constraints.html -[`field`]: https://docs.rs/arrow/latest/arrow/datatype/struct.Field.html +[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/sqlparser/ast/enum.TableConstraint.html +[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/struct.Constraints.html +[`field`]: https://docs.rs/arrow/latest/arrow/datatypes/struct.Field.html diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md deleted file mode 100644 index 157e0339e1eff..0000000000000 --- a/docs/source/library-user-guide/upgrading.md +++ /dev/null @@ -1,2180 +0,0 @@ - - -# Upgrade Guides - -## DataFusion `53.0.0` - -**Note:** DataFusion `53.0.0` has not been released yet. 
The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. - -### `SimplifyInfo` trait removed, `SimplifyContext` now uses builder-style API - -The `SimplifyInfo` trait has been removed and replaced with the concrete `SimplifyContext` struct. This simplifies the expression simplification API and removes the need for trait objects. - -**Who is affected:** - -- Users who implemented custom `SimplifyInfo` implementations -- Users who implemented `ScalarUDFImpl::simplify()` for custom scalar functions -- Users who directly use `SimplifyContext` or `ExprSimplifier` - -**Breaking changes:** - -1. The `SimplifyInfo` trait has been removed entirely -2. `SimplifyContext` no longer takes `&ExecutionProps` - it now uses a builder-style API with direct fields -3. `ScalarUDFImpl::simplify()` now takes `&SimplifyContext` instead of `&dyn SimplifyInfo` -4. Time-dependent function simplification (e.g., `now()`) is now optional - if `query_execution_start_time` is `None`, these functions won't be simplified - -**Migration guide:** - -If you implemented a custom `SimplifyInfo`: - -**Before:** - -```rust,ignore -impl SimplifyInfo for MySimplifyInfo { - fn is_boolean_type(&self, expr: &Expr) -> Result { ... } - fn nullable(&self, expr: &Expr) -> Result { ... } - fn execution_props(&self) -> &ExecutionProps { ... } - fn get_data_type(&self, expr: &Expr) -> Result { ... 
} -} -``` - -**After:** - -Use `SimplifyContext` directly with the builder-style API: - -```rust,ignore -let context = SimplifyContext::default() - .with_schema(schema) - .with_config_options(config_options) - .with_query_execution_start_time(Some(Utc::now())); // or use .with_current_time() -``` - -If you implemented `ScalarUDFImpl::simplify()`: - -**Before:** - -```rust,ignore -fn simplify( - &self, - args: Vec, - info: &dyn SimplifyInfo, -) -> Result { - let now_ts = info.execution_props().query_execution_start_time; - // ... -} -``` - -**After:** - -```rust,ignore -fn simplify( - &self, - args: Vec, - info: &SimplifyContext, -) -> Result { - // query_execution_start_time is now Option> - // Return Original if time is not set (simplification skipped) - let Some(now_ts) = info.query_execution_start_time() else { - return Ok(ExprSimplifyResult::Original(args)); - }; - // ... -} -``` - -If you created `SimplifyContext` from `ExecutionProps`: - -**Before:** - -```rust,ignore -let props = ExecutionProps::new(); -let context = SimplifyContext::new(&props).with_schema(schema); -``` - -**After:** - -```rust,ignore -let context = SimplifyContext::default() - .with_schema(schema) - .with_config_options(config_options) - .with_current_time(); // Sets query_execution_start_time to Utc::now() -``` - -See [`SimplifyContext` documentation](https://docs.rs/datafusion-expr/latest/datafusion_expr/simplify/struct.SimplifyContext.html) for more details. - -## DataFusion `52.0.0` - -### Changes to DFSchema API - -To permit more efficient planning, several methods on `DFSchema` have been -changed to return references to the underlying [`&FieldRef`] rather than -[`&Field`]. This allows planners to more cheaply copy the references via -`Arc::clone` rather than cloning the entire `Field` structure. - -You may need to change code to use `Arc::clone` instead of `.as_ref().clone()` -directly on the `Field`. 
For example: - -```diff -- let field = df_schema.field("my_column").as_ref().clone(); -+ let field = Arc::clone(df_schema.field("my_column")); -``` - -### ListingTableProvider now caches `LIST` commands - -In prior versions, `ListingTableProvider` would issue `LIST` commands to -the underlying object store each time it needed to list files for a query. -To improve performance, `ListingTableProvider` now caches the results of -`LIST` commands for the lifetime of the `ListingTableProvider` instance or -until a cache entry expires. - -Note that by default the cache has no expiration time, so if files are added or removed -from the underlying object store, the `ListingTableProvider` will not see -those changes until the `ListingTableProvider` instance is dropped and recreated. - -You can configure the maximum cache size and cache entry expiration time via configuration options: - -- `datafusion.runtime.list_files_cache_limit` - Limits the size of the cache in bytes -- `datafusion.runtime.list_files_cache_ttl` - Limits the TTL (time-to-live) of an entry in seconds - -Detailed configuration information can be found in the [DataFusion Runtime -Configuration](https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings) user's guide. - -Caching can be disabled by setting the limit to 0: - -```sql -SET datafusion.runtime.list_files_cache_limit TO "0K"; -``` - -Note that the internal API has changed to use a trait `ListFilesCache` instead of a type alias. - -### `newlines_in_values` moved from `FileScanConfig` to `CsvOptions` - -The CSV-specific `newlines_in_values` configuration option has been moved from `FileScanConfig` to `CsvOptions`, as it only applies to CSV file parsing. 
- -**Who is affected:** - -- Users who set `newlines_in_values` via `FileScanConfigBuilder::with_newlines_in_values()` - -**Migration guide:** - -Set `newlines_in_values` in `CsvOptions` instead of on `FileScanConfigBuilder`: - -**Before:** - -```rust,ignore -let source = Arc::new(CsvSource::new(file_schema.clone())); -let config = FileScanConfigBuilder::new(object_store_url, source) - .with_newlines_in_values(true) - .build(); -``` - -**After:** - -```rust,ignore -let options = CsvOptions { - newlines_in_values: Some(true), - ..Default::default() -}; -let source = Arc::new(CsvSource::new(file_schema.clone()) - .with_csv_options(options)); -let config = FileScanConfigBuilder::new(object_store_url, source) - .build(); -``` - -### Removal of `pyarrow` feature - -The `pyarrow` feature flag has been removed. This feature has been migrated to -the `datafusion-python` repository since version `44.0.0`. - -### Refactoring of `FileSource` constructors and `FileScanConfigBuilder` to accept schemas upfront - -The way schemas are passed to file sources and scan configurations has been significantly refactored. File sources now require the schema (including partition columns) to be provided at construction time, and `FileScanConfigBuilder` no longer takes a separate schema parameter. - -**Who is affected:** - -- Users who create `FileScanConfig` or file sources (`ParquetSource`, `CsvSource`, `JsonSource`, `AvroSource`) directly -- Users who implement custom `FileFormat` implementations - -**Key changes:** - -1. **FileSource constructors now require TableSchema**: All built-in file sources now take the schema in their constructor: - - ```diff - - let source = ParquetSource::default(); - + let source = ParquetSource::new(table_schema); - ``` - -2. **FileScanConfigBuilder no longer takes schema as a parameter**: The schema is now passed via the FileSource: - - ```diff - - FileScanConfigBuilder::new(url, schema, source) - + FileScanConfigBuilder::new(url, source) - ``` - -3. 
**Partition columns are now part of TableSchema**: The `with_table_partition_cols()` method has been removed from `FileScanConfigBuilder`. Partition columns are now passed as part of the `TableSchema` to the FileSource constructor: - - ```diff - + let table_schema = TableSchema::new( - + file_schema, - + vec![Arc::new(Field::new("date", DataType::Utf8, false))], - + ); - + let source = ParquetSource::new(table_schema); - let config = FileScanConfigBuilder::new(url, source) - - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) - .with_file(partitioned_file) - .build(); - ``` - -4. **FileFormat::file_source() now takes TableSchema parameter**: Custom `FileFormat` implementations must be updated: - ```diff - impl FileFormat for MyFileFormat { - - fn file_source(&self) -> Arc { - + fn file_source(&self, table_schema: TableSchema) -> Arc { - - Arc::new(MyFileSource::default()) - + Arc::new(MyFileSource::new(table_schema)) - } - } - ``` - -**Migration examples:** - -For Parquet files: - -```diff -- let source = Arc::new(ParquetSource::default()); -- let config = FileScanConfigBuilder::new(url, schema, source) -+ let table_schema = TableSchema::new(schema, vec![]); -+ let source = Arc::new(ParquetSource::new(table_schema)); -+ let config = FileScanConfigBuilder::new(url, source) - .with_file(partitioned_file) - .build(); -``` - -For CSV files with partition columns: - -```diff -- let source = Arc::new(CsvSource::new(true, b',', b'"')); -- let config = FileScanConfigBuilder::new(url, file_schema, source) -- .with_table_partition_cols(vec![Field::new("year", DataType::Int32, false)]) -+ let options = CsvOptions { -+ has_header: Some(true), -+ delimiter: b',', -+ quote: b'"', -+ ..Default::default() -+ }; -+ let table_schema = TableSchema::new( -+ file_schema, -+ vec![Arc::new(Field::new("year", DataType::Int32, false))], -+ ); -+ let source = Arc::new(CsvSource::new(table_schema).with_csv_options(options)); -+ let config = 
FileScanConfigBuilder::new(url, source) - .build(); -``` - -### Adaptive filter representation in Parquet filter pushdown - -As of Arrow 57.1.0, DataFusion uses a new adaptive filter strategy when -evaluating pushed down filters for Parquet files. This new strategy improves -performance for certain types of queries where the results of filtering are -more efficiently represented with a bitmask rather than a selection. -See [arrow-rs #5523] for more details. - -This change only applies to the built-in Parquet data source with filter-pushdown enabled ( -which is [not yet the default behavior]). - -You can disable the new behavior by setting the -`datafusion.execution.parquet.force_filter_selections` [configuration setting] to true. - -```sql -> set datafusion.execution.parquet.force_filter_selections = true; -``` - -[arrow-rs #5523]: https://github.com/apache/arrow-rs/issues/5523 -[configuration setting]: https://datafusion.apache.org/user-guide/configs.html -[not yet the default behavior]: https://github.com/apache/datafusion/issues/3463 - -### Statistics handling moved from `FileSource` to `FileScanConfig` - -Statistics are now managed directly by `FileScanConfig` instead of being delegated to `FileSource` implementations. This simplifies the `FileSource` trait and provides more consistent statistics handling across all file formats. - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations - -**Breaking changes:** - -Two methods have been removed from the `FileSource` trait: - -- `with_statistics(&self, statistics: Statistics) -> Arc` -- `statistics(&self) -> Result` - -**Migration guide:** - -If you have a custom `FileSource` implementation, you need to: - -1. Remove the `with_statistics` method implementation -2. Remove the `statistics` method implementation -3. 
Remove any internal state that was storing statistics - -**Before:** - -```rust,ignore -#[derive(Clone)] -struct MyCustomSource { - table_schema: TableSchema, - projected_statistics: Option, - // other fields... -} - -impl FileSource for MyCustomSource { - fn with_statistics(&self, statistics: Statistics) -> Arc { - Arc::new(Self { - table_schema: self.table_schema.clone(), - projected_statistics: Some(statistics), - // other fields... - }) - } - - fn statistics(&self) -> Result { - Ok(self.projected_statistics.clone().unwrap_or_else(|| - Statistics::new_unknown(self.table_schema.file_schema()) - )) - } - - // other methods... -} -``` - -**After:** - -```rust,ignore -#[derive(Clone)] -struct MyCustomSource { - table_schema: TableSchema, - // projected_statistics field removed - // other fields... -} - -impl FileSource for MyCustomSource { - // with_statistics method removed - // statistics method removed - - // other methods... -} -``` - -**Accessing statistics:** - -Statistics are now accessed through `FileScanConfig` instead of `FileSource`: - -```diff -- let stats = config.file_source.statistics()?; -+ let stats = config.statistics(); -``` - -Note that `FileScanConfig::statistics()` automatically marks statistics as inexact when filters are present, ensuring correctness when filters are pushed down. - -### Partition column handling moved out of `PhysicalExprAdapter` - -Partition column replacement is now a separate preprocessing step performed before expression rewriting via `PhysicalExprAdapter`. This change provides better separation of concerns and makes the adapter more focused on schema differences rather than partition value substitution. - -**Who is affected:** - -- Users who have custom implementations of `PhysicalExprAdapterFactory` that handle partition columns -- Users who directly use the `FilePruner` API - -**Breaking changes:** - -1. 
`FilePruner::try_new()` signature changed: the `partition_fields` parameter has been removed since partition column handling is now done separately -2. Partition column replacement must now be done via `replace_columns_with_literals()` before expressions are passed to the adapter - -**Migration guide:** - -If you have code that creates a `FilePruner` with partition fields: - -**Before:** - -```rust,ignore -use datafusion_pruning::FilePruner; - -let pruner = FilePruner::try_new( - predicate, - file_schema, - partition_fields, // This parameter is removed - file_stats, -)?; -``` - -**After:** - -```rust,ignore -use datafusion_pruning::FilePruner; - -// Partition fields are no longer needed -let pruner = FilePruner::try_new( - predicate, - file_schema, - file_stats, -)?; -``` - -If you have custom code that relies on `PhysicalExprAdapter` to handle partition columns, you must now call `replace_columns_with_literals()` separately: - -**Before:** - -```rust,ignore -// Adapter handled partition column replacement internally -let adapted_expr = adapter.rewrite(expr)?; -``` - -**After:** - -```rust,ignore -use datafusion_physical_expr_adapter::replace_columns_with_literals; - -// Replace partition columns first -let expr_with_literals = replace_columns_with_literals(expr, &partition_values)?; -// Then apply the adapter -let adapted_expr = adapter.rewrite(expr_with_literals)?; -``` - -### `build_row_filter` signature simplified - -The `build_row_filter` function in `datafusion-datasource-parquet` has been simplified to take a single schema parameter instead of two. -The expectation is now that the filter has been adapted to the physical file schema (the arrow representation of the parquet file's schema) before being passed to this function -using a `PhysicalExprAdapter` for example. 
- -**Who is affected:** - -- Users who call `build_row_filter` directly - -**Breaking changes:** - -The function signature changed from: - -```rust,ignore -pub fn build_row_filter( - expr: &Arc, - physical_file_schema: &SchemaRef, - predicate_file_schema: &SchemaRef, // removed - metadata: &ParquetMetaData, - reorder_predicates: bool, - file_metrics: &ParquetFileMetrics, -) -> Result> -``` - -To: - -```rust,ignore -pub fn build_row_filter( - expr: &Arc, - file_schema: &SchemaRef, - metadata: &ParquetMetaData, - reorder_predicates: bool, - file_metrics: &ParquetFileMetrics, -) -> Result> -``` - -**Migration guide:** - -Remove the duplicate schema parameter from your call: - -```diff -- build_row_filter(&predicate, &file_schema, &file_schema, metadata, reorder, metrics) -+ build_row_filter(&predicate, &file_schema, metadata, reorder, metrics) -``` - -### Planner now requires explicit opt-in for WITHIN GROUP syntax - -The SQL planner now enforces the aggregate UDF contract more strictly: the -`WITHIN GROUP (ORDER BY ...)` syntax is accepted only if the aggregate UDAF -explicitly advertises support by returning `true` from -`AggregateUDFImpl::supports_within_group_clause()`. - -Previously the planner forwarded a `WITHIN GROUP` clause to order-sensitive -aggregates even when they did not implement ordered-set semantics, which could -cause queries such as `SUM(x) WITHIN GROUP (ORDER BY x)` to plan successfully. -This behavior was too permissive and has been changed to match PostgreSQL and -the documented semantics. - -Migration: If your UDAF intentionally implements ordered-set semantics and -wants to accept the `WITHIN GROUP` SQL syntax, update your implementation to -return `true` from `supports_within_group_clause()` and handle the ordering -semantics in your accumulator implementation. 
If your UDAF is merely -order-sensitive (but not an ordered-set aggregate), do not advertise -`supports_within_group_clause()` and clients should use alternative function -signatures (for example, explicit ordering as a function argument) instead. - -### `AggregateUDFImpl::supports_null_handling_clause` now defaults to `false` - -This method specifies whether an aggregate function allows `IGNORE NULLS`/`RESPECT NULLS` -during SQL parsing, with the implication it respects these configs during computation. - -Most DataFusion aggregate functions silently ignored this syntax in prior versions -as they did not make use of it and it was permitted by default. We change this so -only the few functions which do respect this clause (e.g. `array_agg`, `first_value`, -`last_value`) need to implement it. - -Custom user defined aggregate functions will also error if this syntax is used, -unless they explicitly declare support by overriding the method. - -For example, SQL parsing will now fail for queries such as this: - -```sql -SELECT median(c1) IGNORE NULLS FROM table -``` - -Instead of silently succeeding. - -### API change for `CacheAccessor` trait - -The remove API no longer requires a mutable instance - -### FFI crate updates - -Many of the structs in the `datafusion-ffi` crate have been updated to allow easier -conversion to the underlying trait types they represent. This simplifies some code -paths, but also provides an additional improvement in cases where library code goes -through a round trip via the foreign function interface. - -To update your code, suppose you have a `FFI_SchemaProvider` called `ffi_provider` -and you wish to use this as a `SchemaProvider`. 
In the old approach you would do -something like: - -```rust,ignore - let foreign_provider: ForeignSchemaProvider = ffi_provider.into(); - let foreign_provider = Arc::new(foreign_provider) as Arc; -``` - -This code should now be written as: - -```rust,ignore - let foreign_provider: Arc = ffi_provider.into(); - let foreign_provider = foreign_provider as Arc; -``` - -For the case of user defined functions, the updates are similar but you -may need to change the way you call the creation of the `ScalarUDF`. -Aggregate and window functions follow the same pattern. - -Previously you may write: - -```rust,ignore - let foreign_udf: ForeignScalarUDF = ffi_udf.try_into()?; - let foreign_udf: ScalarUDF = foreign_udf.into(); -``` - -Instead this should now be: - -```rust,ignore - let foreign_udf: Arc = ffi_udf.into(); - let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); -``` - -When creating any of the following structs, we now require the user to -provide a `TaskContextProvider` and optionally a `LogicalExtensionCodec`: - -- `FFI_CatalogListProvider` -- `FFI_CatalogProvider` -- `FFI_SchemaProvider` -- `FFI_TableProvider` -- `FFI_TableFunction` - -Each of these structs has a `new()` and a `new_with_ffi_codec()` method for -instantiation. For example, when you previously would write - -```rust,ignore - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new(table, None); -``` - -Now you will need to provide a `TaskContextProvider`. The most common -implementation of this trait is `SessionContext`. - -```rust,ignore - let ctx = Arc::new(SessionContext::default()); - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new(table, None, ctx, None); -``` - -The alternative function to create these structures may be more convenient -if you are doing many of these operations. A `FFI_LogicalExtensionCodec` will -store the `TaskContextProvider` as well. 
- -```rust,ignore - let codec = Arc::new(DefaultLogicalExtensionCodec {}); - let ctx = Arc::new(SessionContext::default()); - let ffi_codec = FFI_LogicalExtensionCodec::new(codec, None, ctx); - let table = Arc::new(MyTableProvider::new()); - let ffi_table = FFI_TableProvider::new_with_ffi_codec(table, None, ffi_codec); -``` - -Additional information about the usage of the `TaskContextProvider` can be -found in the crate README. - -Additionally, the FFI structure for Scalar UDF's no longer contains a -`return_type` call. This code was not used since the `ForeignScalarUDF` -struct implements the `return_field_from_args` instead. - -### Projection handling moved from FileScanConfig to FileSource - -Projection handling has been moved from `FileScanConfig` into `FileSource` implementations. This enables format-specific projection pushdown (e.g., Parquet can push down struct field access, Vortex can push down computed expressions into un-decoded data). - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations -- Users who use `FileScanConfigBuilder::with_projection_indices` directly - -**Breaking changes:** - -1. **`FileSource::with_projection` replaced with `try_pushdown_projection`:** - - The `with_projection(&self, config: &FileScanConfig) -> Arc` method has been removed and replaced with `try_pushdown_projection(&self, projection: &ProjectionExprs) -> Result>>`. - -2. **`FileScanConfig.projection_exprs` field removed:** - - Projections are now stored in the `FileSource` directly, not in `FileScanConfig`. - Various public helper methods that access projection information have been removed from `FileScanConfig`. - -3. **`FileScanConfigBuilder::with_projection_indices` now returns `Result`:** - - This method can now fail if the projection pushdown fails. - -4. **`FileSource::create_file_opener` now returns `Result>`:** - - Previously returned `Arc` directly. 
- Any `FileSource` implementation that may fail to create a `FileOpener` should now return an appropriate error. - -5. **`DataSource::try_swapping_with_projection` signature changed:** - - Parameter changed from `&[ProjectionExpr]` to `&ProjectionExprs`. - -**Migration guide:** - -If you have a custom `FileSource` implementation: - -**Before:** - -```rust,ignore -impl FileSource for MyCustomSource { - fn with_projection(&self, config: &FileScanConfig) -> Arc { - // Apply projection from config - Arc::new(Self { /* ... */ }) - } - - fn create_file_opener( - &self, - object_store: Arc, - base_config: &FileScanConfig, - partition: usize, - ) -> Arc { - Arc::new(MyOpener { /* ... */ }) - } -} -``` - -**After:** - -```rust,ignore -impl FileSource for MyCustomSource { - fn try_pushdown_projection( - &self, - projection: &ProjectionExprs, - ) -> Result>> { - // Return None if projection cannot be pushed down - // Return Some(new_source) with projection applied if it can - Ok(Some(Arc::new(Self { - projection: Some(projection.clone()), - /* ... */ - }))) - } - - fn projection(&self) -> Option<&ProjectionExprs> { - self.projection.as_ref() - } - - fn create_file_opener( - &self, - object_store: Arc, - base_config: &FileScanConfig, - partition: usize, - ) -> Result> { - Ok(Arc::new(MyOpener { /* ... */ })) - } -} -``` - -We recommend you look at [#18627](https://github.com/apache/datafusion/pull/18627) -that introduced these changes for more examples for how this was handled for the various built in file sources. - -We have added [`SplitProjection`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.SplitProjection.html) and [`ProjectionOpener`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.ProjectionOpener.html) helpers to make it easier to handle projections in your `FileSource` implementations. 
- -For file sources that can only handle simple column selections (not computed expressions), use the `SplitProjection` and `ProjectionOpener` helpers to split the projection into pushdownable and non-pushdownable parts: - -```rust,ignore -use datafusion_datasource::projection::{SplitProjection, ProjectionOpener}; - -// In try_pushdown_projection: -let split = SplitProjection::new(projection, self.table_schema())?; -// Use split.file_projection() for what to push down to the file format -// The ProjectionOpener wrapper will handle the rest -``` - -**For `FileScanConfigBuilder` users:** - -```diff -let config = FileScanConfigBuilder::new(url, source) -- .with_projection_indices(Some(vec![0, 2, 3])) -+ .with_projection_indices(Some(vec![0, 2, 3]))? - .build(); -``` - -### `SchemaAdapter` and `SchemaAdapterFactory` completely removed - -Following the deprecation announced in [DataFusion 49.0.0](#deprecating-schemaadapterfactory-and-schemaadapter), `SchemaAdapterFactory` has been fully removed from Parquet scanning. This applies to both: - -The following symbols have been deprecated and will be removed in the next release: - -- `SchemaAdapter` trait -- `SchemaAdapterFactory` trait -- `SchemaMapper` trait -- `SchemaMapping` struct -- `DefaultSchemaAdapterFactory` struct - -These types were previously used to adapt record batch schemas during file reading. -This functionality has been replaced by `PhysicalExprAdapterFactory`, which rewrites expressions at planning time rather than transforming batches at runtime. -If you were using a custom `SchemaAdapterFactory` for schema adaptation (e.g., default column values, type coercion), you should now implement `PhysicalExprAdapterFactory` instead. -See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for how to implement a custom `PhysicalExprAdapterFactory`. 
- -**Migration guide:** - -If you implemented a custom `SchemaAdapterFactory`, migrate to `PhysicalExprAdapterFactory`. -See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for a complete implementation. - -## DataFusion `51.0.0` - -### `arrow` / `parquet` updated to 57.0.0 - -### Upgrade to arrow `57.0.0` and parquet `57.0.0` - -This version of DataFusion upgrades the underlying Apache Arrow implementation -to version `57.0.0`, including several dependent crates such as `prost`, -`tonic`, `pyo3`, and `substrait`. See the [release -notes](https://github.com/apache/arrow-rs/releases/tag/57.0.0) for more details. - -### `MSRV` updated to 1.88.0 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.88.0`]. - -[`1.88.0`]: https://releases.rs/docs/1.88.0/ - -### `FunctionRegistry` exposes two additional methods - -`FunctionRegistry` exposes two additional methods `udafs` and `udwfs` which expose the set of registered user defined aggregation and window function names. To upgrade, implement methods returning the set of registered function names: - -```diff -impl FunctionRegistry for FunctionRegistryImpl { - fn udfs(&self) -> HashSet { - self.scalar_functions.keys().cloned().collect() - } -+ fn udafs(&self) -> HashSet { -+ self.aggregate_functions.keys().cloned().collect() -+ } -+ -+ fn udwfs(&self) -> HashSet { -+ self.window_functions.keys().cloned().collect() -+ } -} -``` - -### `datafusion-proto` uses `TaskContext` rather than `SessionContext` in physical plan serde methods - -There have been changes in the public API methods of `datafusion-proto` which handle physical plan serde.
- -Methods like `physical_plan_from_bytes`, `parse_physical_expr` and similar, expect `TaskContext` instead of `SessionContext`. - -```diff -- let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; -+ let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; -``` - -As `TaskContext` contains `RuntimeEnv`, methods such as `try_into_physical_plan` no longer have an explicit `RuntimeEnv` parameter. - -```diff -let result_exec_plan: Arc = proto -- .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) -+ .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) -``` - -`PhysicalExtensionCodec::try_decode()` expects `TaskContext` instead of `FunctionRegistry`: - -```diff -pub trait PhysicalExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[Arc], -- registry: &dyn FunctionRegistry, -+ ctx: &TaskContext, - ) -> Result>; -``` - -See [issue #17601] for more details. - -[issue #17601]: https://github.com/apache/datafusion/issues/17601 - -### `SessionState`'s `sql_to_statement` method takes `Dialect` rather than a `str` - -The `dialect` parameter of the `sql_to_statement` method defined in `datafusion::execution::session_state::SessionState` -has changed from `&str` to `&Dialect`. -`Dialect` is an enum defined in the `datafusion-common` -crate under the `config` module that provides type safety -and better validation for SQL dialect selection. - -### Reorganization of `ListingTable` into `datafusion-catalog-listing` crate - -There has been a long standing request to remove features such as `ListingTable` -from the `datafusion` crate to support faster build times. The structs -`ListingOptions`, `ListingTable`, and `ListingTableConfig` are now available -within the `datafusion-catalog-listing` crate. These are re-exported in -the `datafusion` crate, so this should have minimal impact on existing users. - -See [issue #14462] and [issue #17713] for more details.
- -[issue #14462]: https://github.com/apache/datafusion/issues/14462 -[issue #17713]: https://github.com/apache/datafusion/issues/17713 - -### Reorganization of `ArrowSource` into `datafusion-datasource-arrow` crate - -To support [issue #17713], the `ArrowSource` code has been moved from -the `datafusion` core crate into its own crate, `datafusion-datasource-arrow`. -This follows the pattern for the AVRO, CSV, JSON, and Parquet data sources. -Users may need to update their paths to account for these changes. - -See [issue #17713] for more details. - -### `FileScanConfig::projection` renamed to `FileScanConfig::projection_exprs` - -The `projection` field in `FileScanConfig` has been renamed to `projection_exprs` and its type has changed from `Option>` to `Option`. This change enables more powerful projection pushdown capabilities by supporting arbitrary physical expressions rather than just column indices. - -**Impact on direct field access:** - -If you directly access the `projection` field: - -```rust,ignore -let config: FileScanConfig = ...; -let projection = config.projection; -``` - -You should update to: - -```rust,ignore -let config: FileScanConfig = ...; -let projection_exprs = config.projection_exprs; -``` - -**Impact on builders:** - -The `FileScanConfigBuilder::with_projection()` method has been deprecated in favor of `with_projection_indices()`: - -```diff -let config = FileScanConfigBuilder::new(url, file_source) -- .with_projection(Some(vec![0, 2, 3])) -+ .with_projection_indices(Some(vec![0, 2, 3])) - .build(); -``` - -Note: `with_projection()` still works but is deprecated and will be removed in a future release. - -**What is `ProjectionExprs`?** - -`ProjectionExprs` is a new type that represents a list of physical expressions for projection.
While it can be constructed from column indices (which is what `with_projection_indices` does internally), it also supports arbitrary physical expressions, enabling advanced features like expression evaluation during scanning. - -You can access column indices from `ProjectionExprs` using its methods if needed: - -```rust,ignore -let projection_exprs: ProjectionExprs = ...; -// Get the column indices if the projection only contains simple column references -let indices = projection_exprs.column_indices(); -``` - -### `DESCRIBE query` support - -`DESCRIBE query` was previously an alias for `EXPLAIN query`, which outputs the -_execution plan_ of the query. With this release, `DESCRIBE query` now outputs -the computed _schema_ of the query, consistent with the behavior of `DESCRIBE table_name`. - -### `datafusion.execution.time_zone` default configuration changed - -The default value for `datafusion.execution.time_zone` previously was a string value of `+00:00` (GMT/Zulu time). -This was changed to be an `Option` with a default of `None`. If you want to change the timezone back -to the previous value you can execute the sql: - -```sql -SET -TIMEZONE = '+00:00'; -``` - -This change was made to better support using the default timezone in scalar UDF functions such as -`now`, `current_date`, `current_time`, and `to_timestamp` among others. - -### Introduction of `TableSchema` and changes to `FileSource::with_schema()` method - -A new `TableSchema` struct has been introduced in the `datafusion-datasource` crate to better manage table schemas with partition columns. 
This struct helps distinguish between: - -- **File schema**: The schema of actual data files on disk -- **Partition columns**: Columns derived from directory structure (e.g., Hive-style partitioning) -- **Table schema**: The complete schema combining both file and partition columns - -As part of this change, the `FileSource::with_schema()` method signature has changed from accepting a `SchemaRef` to accepting a `TableSchema`. - -**Who is affected:** - -- Users who have implemented custom `FileSource` implementations will need to update their code -- Users who only use built-in file sources (Parquet, CSV, JSON, AVRO, Arrow) are not affected - -**Migration guide for custom `FileSource` implementations:** - -```diff - use datafusion_datasource::file::FileSource; --use arrow::datatypes::SchemaRef; -+use datafusion_datasource::TableSchema; - - impl FileSource for MyCustomSource { -- fn with_schema(&self, schema: SchemaRef) -> Arc { -+ fn with_schema(&self, schema: TableSchema) -> Arc { - Arc::new(Self { -- schema: Some(schema), -+ // Use schema.file_schema() to get the file schema without partition columns -+ schema: Some(Arc::clone(schema.file_schema())), - ..self.clone() - }) - } - } -``` - -For implementations that need access to partition columns: - -```rust,ignore -fn with_schema(&self, schema: TableSchema) -> Arc { - Arc::new(Self { - file_schema: Arc::clone(schema.file_schema()), - partition_cols: schema.table_partition_cols().clone(), - table_schema: Arc::clone(schema.table_schema()), - ..self.clone() - }) -} -``` - -**Note**: Most `FileSource` implementations only need to store the file schema (without partition columns), as shown in the first example. 
The second pattern of storing all three schema components is typically only needed for advanced use cases where you need access to different schema representations for different operations (e.g., ParquetSource uses the file schema for building pruning predicates but needs the table schema for filter pushdown logic). - -**Using `TableSchema` directly:** - -If you're constructing a `FileScanConfig` or working with table schemas and partition columns, you can now use `TableSchema`: - -```rust -use datafusion_datasource::TableSchema; -use arrow::datatypes::{Schema, Field, DataType}; -use std::sync::Arc; - -// Create a TableSchema with partition columns -let file_schema = Arc::new(Schema::new(vec![ - Field::new("user_id", DataType::Int64, false), - Field::new("amount", DataType::Float64, false), -])); - -let partition_cols = vec![ - Arc::new(Field::new("date", DataType::Utf8, false)), - Arc::new(Field::new("region", DataType::Utf8, false)), -]; - -let table_schema = TableSchema::new(file_schema, partition_cols); - -// Access different schema representations -let file_schema_ref = table_schema.file_schema(); // Schema without partition columns -let full_schema = table_schema.table_schema(); // Complete schema with partition columns -let partition_cols_ref = table_schema.table_partition_cols(); // Just the partition columns -``` - -### `AggregateUDFImpl::is_ordered_set_aggregate` has been renamed to `AggregateUDFImpl::supports_within_group_clause` - -This method has been renamed to better reflect the actual impact it has for aggregate UDF implementations. -The accompanying `AggregateUDF::is_ordered_set_aggregate` has also been renamed to `AggregateUDF::supports_within_group_clause`. -No functionality has been changed with regards to this method; it still refers only to permitting use of `WITHIN GROUP` -SQL syntax for the aggregate function. 
- -## DataFusion `50.0.0` - -### ListingTable automatically detects Hive Partitioned tables - -DataFusion 50.0.0 automatically infers Hive partitions when using the `ListingTableFactory` and `CREATE EXTERNAL TABLE`. Previously, -when creating a `ListingTable`, datasets that use Hive partitioning (e.g. -`/table_root/column1=value1/column2=value2/data.parquet`) would not have the Hive columns reflected in -the table's schema or data. The previous behavior can be -restored by setting the `datafusion.execution.listing_table_factory_infer_partitions` configuration option to `false`. -See [issue #17049] for more details. - -[issue #17049]: https://github.com/apache/datafusion/issues/17049 - -### `MSRV` updated to 1.86.0 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.86.0`]. -See [#17230] for details. - -[`1.86.0`]: https://releases.rs/docs/1.86.0/ -[#17230]: https://github.com/apache/datafusion/pull/17230 - -### `ScalarUDFImpl`, `AggregateUDFImpl` and `WindowUDFImpl` traits now require `PartialEq`, `Eq`, and `Hash` traits - -To address the error-proneness of the `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and -`WindowUDFImpl::equals` methods and to make it easy to implement function equality correctly, -the `equals` and `hash_value` methods have been removed from `ScalarUDFImpl`, `AggregateUDFImpl` -and `WindowUDFImpl` traits. They are replaced by the requirement to implement the `PartialEq`, `Eq`, -and `Hash` traits on any type implementing `ScalarUDFImpl`, `AggregateUDFImpl` or `WindowUDFImpl`. -Please see [issue #16677] for more details. - -Most of the scalar functions are stateless and have a `signature` field. These can be migrated -using regular expressions: - -- search for `\#\[derive\(Debug\)\](\n *(pub )?struct \w+ \{\n *signature\: Signature\,\n *\})`, -- replace with `#[derive(Debug, PartialEq, Eq, Hash)]$1`, -- review all the changes and make sure only function structs were changed.
- -[issue #16677]: https://github.com/apache/datafusion/issues/16677 - -### `AsyncScalarUDFImpl::invoke_async_with_args` returns `ColumnarValue` - -In order to enable single value optimizations and be consistent with other -user defined function APIs, the `AsyncScalarUDFImpl::invoke_async_with_args` method now -returns a `ColumnarValue` instead of an `ArrayRef`. - -To upgrade, change the return type of your implementation - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - return array_ref; // old code - } -} -# */ -``` - -To return a `ColumnarValue` - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - return ColumnarValue::from(array_ref); // new code - } -} -# */ -``` - -See [#16896](https://github.com/apache/datafusion/issues/16896) for more details. - -### `ProjectionExpr` changed from type alias to struct - -`ProjectionExpr` has been changed from a type alias to a struct with named fields to improve code clarity and maintainability. - -**Before:** - -```rust,ignore -pub type ProjectionExpr = (Arc, String); -``` - -**After:** - -```rust,ignore -#[derive(Debug, Clone)] -pub struct ProjectionExpr { - pub expr: Arc, - pub alias: String, -} -``` - -To upgrade your code: - -- Replace tuple construction `(expr, alias)` with `ProjectionExpr::new(expr, alias)` or `ProjectionExpr { expr, alias }` -- Replace tuple field access `.0` and `.1` with `.expr` and `.alias` -- Update pattern matching from `(expr, alias)` to `ProjectionExpr { expr, alias }` - -This mainly impacts use of `ProjectionExec`.
- -This change was done in [#17398] - -[#17398]: https://github.com/apache/datafusion/pull/17398 - -### `SessionState`, `SessionConfig`, and `OptimizerConfig` returns `&Arc` instead of `&ConfigOptions` - -To provide broader access to `ConfigOptions` and reduce required clones, some -APIs have been changed to return a `&Arc` instead of a -`&ConfigOptions`. This allows sharing the same `ConfigOptions` across multiple -threads without needing to clone the entire `ConfigOptions` structure unless it -is modified. - -Most users will not be impacted by this change since the Rust compiler typically -automatically dereference the `Arc` when needed. However, in some cases you may -have to change your code to explicitly call `as_ref()` for example, from - -```rust -# /* comment to avoid running -let optimizer_config: &ConfigOptions = state.options(); -# */ -``` - -To - -```rust -# /* comment to avoid running -let optimizer_config: &ConfigOptions = state.options().as_ref(); -# */ -``` - -See PR [#16970](https://github.com/apache/datafusion/pull/16970) - -### API Change to `AsyncScalarUDFImpl::invoke_async_with_args` - -The `invoke_async_with_args` method of the `AsyncScalarUDFImpl` trait has been -updated to remove the `_option: &ConfigOptions` parameter to simplify the API -now that the `ConfigOptions` can be accessed through the `ScalarFunctionArgs` -parameter. - -You can change your code like this - -```rust -# /* comment to avoid running -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - _option: &ConfigOptions, - ) -> Result { - .. - } - ... -} -# */ -``` - -To this: - -```rust -# /* comment to avoid running - -impl AsyncScalarUDFImpl for AskLLM { - async fn invoke_async_with_args( - &self, - args: ScalarFunctionArgs, - ) -> Result { - let options = &args.config_options; - .. - } - ... 
-} -# */ -``` - -### Schema Rewriter Module Moved to New Crate - -The `schema_rewriter` module and its associated symbols have been moved from `datafusion_physical_expr` to a new crate `datafusion_physical_expr_adapter`. This affects the following symbols: - -- `DefaultPhysicalExprAdapter` -- `DefaultPhysicalExprAdapterFactory` -- `PhysicalExprAdapter` -- `PhysicalExprAdapterFactory` - -To upgrade, change your imports to: - -```rust -use datafusion_physical_expr_adapter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, - PhysicalExprAdapter, PhysicalExprAdapterFactory -}; -``` - -### Upgrade to arrow `56.0.0` and parquet `56.0.0` - -This version of DataFusion upgrades the underlying Apache Arrow implementation -to version `56.0.0`. See the [release notes](https://github.com/apache/arrow-rs/releases/tag/56.0.0) -for more details. - -### Added `ExecutionPlan::reset_state` - -In order to fix a bug in DataFusion `49.0.0` where dynamic filters (currently only generated in the presence of a query such as `ORDER BY ... LIMIT ...`) -produced incorrect results in recursive queries, a new method `reset_state` has been added to the `ExecutionPlan` trait. - -Any `ExecutionPlan` that needs to maintain internal state or references to other nodes in the execution plan tree should implement this method to reset that state. -See [#17028] for more details and an example implementation for `SortExec`. - -[#17028]: https://github.com/apache/datafusion/pull/17028 - -### Nested Loop Join input sort order cannot be preserved - -The Nested Loop Join operator has been rewritten from scratch to improve performance and memory efficiency. From the micro-benchmarks: this change introduces up to 5X speed-up and uses only 1% memory in extreme cases compared to the previous implementation. - -However, the new implementation cannot preserve input sort order like the old version could. 
This is a fundamental design trade-off that prioritizes performance and memory efficiency over sort order preservation. - -See [#16996] for details. - -[#16996]: https://github.com/apache/datafusion/pull/16996 - -### Add `as_any()` method to `LazyBatchGenerator` - -To help with protobuf serialization, the `as_any()` method has been added to the `LazyBatchGenerator` trait. This means you will need to add `as_any()` to your implementation of `LazyBatchGenerator`: - -```rust -# /* comment to avoid running - -impl LazyBatchGenerator for MyBatchGenerator { - fn as_any(&self) -> &dyn Any { - self - } - - ... -} - -# */ -``` - -See [#17200](https://github.com/apache/datafusion/pull/17200) for details. - -### Refactored `DataSource::try_swapping_with_projection` - -We refactored `DataSource::try_swapping_with_projection` to simplify the method and minimize leakage across the ExecutionPlan <-> DataSource abstraction layer. -Reimplementation for any custom `DataSource` should be relatively straightforward, see [#17395] for more details. - -[#17395]: https://github.com/apache/datafusion/pull/17395/ - -### `FileOpenFuture` now uses `DataFusionError` instead of `ArrowError` - -The `FileOpenFuture` type alias has been updated to use `DataFusionError` instead of `ArrowError` for its error type. This change affects the `FileOpener` trait and any implementations that work with file streaming operations. - -**Before:** - -```rust,ignore -pub type FileOpenFuture = BoxFuture<'static, Result>>>; -``` - -**After:** - -```rust,ignore -pub type FileOpenFuture = BoxFuture<'static, Result>>>; -``` - -If you have custom implementations of `FileOpener` or work directly with `FileOpenFuture`, you'll need to update your error handling to use `DataFusionError` instead of `ArrowError`. The `FileStreamState` enum's `Open` variant has also been updated accordingly. See [#17397] for more details. 
- -[#17397]: https://github.com/apache/datafusion/pull/17397 - -### FFI user defined aggregate function signature change - -The Foreign Function Interface (FFI) signature for user defined aggregate functions -has been updated to call `return_field` instead of `return_type` on the underlying -aggregate function. This is to support metadata handling with these aggregate functions. -This change should be transparent to most users. If you have written unit tests to call -`return_type` directly, you may need to change them to calling `return_field` instead. - -This update is a breaking change to the FFI API. The current best practice when using the -FFI crate is to ensure that all libraries that are interacting are using the same -underlying Rust version. Issue [#17374] has been opened to discuss stabilization of -this interface so that these libraries can be used across different DataFusion versions. - -See [#17407] for details. - -[#17407]: https://github.com/apache/datafusion/pull/17407 -[#17374]: https://github.com/apache/datafusion/issues/17374 - -### Added `PhysicalExpr::is_volatile_node` - -We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile: - -```rust,ignore -impl PhysicalExpr for MyRandomExpr { - fn is_volatile_node(&self) -> bool { - true - } -} -``` - -We've shipped this with a default value of `false` to minimize breakage but we highly recommend that implementers of `PhysicalExpr` opt into a behavior, even if it is returning `false`. - -You can see more discussion and example implementations in [#17351]. - -[#17351]: https://github.com/apache/datafusion/pull/17351 - -## DataFusion `49.0.0` - -### `MSRV` updated to 1.85.1 - -The Minimum Supported Rust Version (MSRV) has been updated to [`1.85.1`]. See -[#16728] for details. 
- -[`1.85.1`]: https://releases.rs/docs/1.85.1/ -[#16728]: https://github.com/apache/datafusion/pull/16728 - -### `DataFusionError` variants are now `Box`ed - -To reduce the size of `DataFusionError`, several variants that were previously stored inline are now `Box`ed. This reduces the size of `Result` and thus stack usage and async state machine size. Please see [#16652] for more details. - -The following variants of `DataFusionError` are now boxed: - -- `ArrowError` -- `SQL` -- `SchemaError` - -This is a breaking change. Code that constructs or matches on these variants will need to be updated. - -For example, to create a `SchemaError`, instead of: - -```rust -# /* comment to avoid running -use datafusion_common::{DataFusionError, SchemaError}; -DataFusionError::SchemaError( - SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }, - Box::new(None) -) -# */ -``` - -You now need to `Box` the inner error: - -```rust -# /* comment to avoid running -use datafusion_common::{DataFusionError, SchemaError}; -DataFusionError::SchemaError( - Box::new(SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }), - Box::new(None) -) -# */ -``` - -[#16652]: https://github.com/apache/datafusion/issues/16652 - -### Metadata on Arrow Types is now represented by `FieldMetadata` - -Metadata from the Arrow `Field` is now stored using the `FieldMetadata` -structure. In prior versions it was stored as both a `HashMap` -and a `BTreeMap`. `FieldMetadata` is easier to work with and -is more efficient. - -To create `FieldMetadata` from a `Field`: - -```rust -# /* comment to avoid running - let metadata = FieldMetadata::from(&field); -# */ -``` - -To add metadata to a `Field`, use the `add_to_field` method: - -```rust -# /* comment to avoid running -let updated_field = metadata.add_to_field(field); -# */ -``` - -See [#16317] for details.
- -[#16317]: https://github.com/apache/datafusion/pull/16317 - -### New `datafusion.execution.spill_compression` configuration option - -DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. - -**Configuration:** - -- **Key**: `datafusion.execution.spill_compression` -- **Default**: `uncompressed` -- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` - -**Usage:** - -```rust -# /* comment to avoid running -use datafusion::prelude::*; -use datafusion_common::config::SpillCompression; - -let config = SessionConfig::default() - .with_spill_compression(SpillCompression::Zstd); -let ctx = SessionContext::new_with_config(config); -# */ -``` - -Or via SQL: - -```sql -SET datafusion.execution.spill_compression = 'zstd'; -``` - -For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../user-guide/configs.md) documentation. - -### Deprecated `map_varchar_to_utf8view` configuration option - -See [issue #16290](https://github.com/apache/datafusion/pull/16290) for more information -The old configuration - -```text -datafusion.sql_parser.map_varchar_to_utf8view -``` - -is now **deprecated** in favor of the unified option below.\ -If you previously used this to control only `VARCHAR`→`Utf8View` mapping, please migrate to `map_string_types_to_utf8view`. - ---- - -### New `map_string_types_to_utf8view` configuration option - -To unify **all** SQL string types (`CHAR`, `VARCHAR`, `TEXT`, `STRING`) to Arrow’s zero‑copy `Utf8View`, DataFusion 49.0.0 introduces: - -- **Key**: `datafusion.sql_parser.map_string_types_to_utf8view` -- **Default**: `true` - -**Description:** - -- When **true** (default), **all** SQL string types are mapped to `Utf8View`, avoiding full‑copy UTF‑8 allocations and improving performance. 
-- When **false**, DataFusion falls back to the legacy `Utf8` mapping for **all** string types. - -#### Examples - -```rust -# /* comment to avoid running -// Disable Utf8View mapping for all SQL string types -let opts = datafusion::sql::planner::ParserOptions::new() - .with_map_string_types_to_utf8view(false); - -// Verify the setting is applied -assert!(!opts.map_string_types_to_utf8view); -# */ -``` - ---- - -```sql --- Disable Utf8View mapping globally -SET datafusion.sql_parser.map_string_types_to_utf8view = false; - --- Now VARCHAR, CHAR, TEXT, STRING all use Utf8 rather than Utf8View -CREATE TABLE my_table (a VARCHAR, b TEXT, c STRING); -DESCRIBE my_table; -``` - -### Deprecating `SchemaAdapterFactory` and `SchemaAdapter` - -We are moving away from converting data (using `SchemaAdapter`) to converting the expressions themselves (which is more efficient and flexible). - -See [issue #16800](https://github.com/apache/datafusion/issues/16800) for more information -The first place this change has taken place is in predicate pushdown for Parquet. -By default if you do not use a custom `SchemaAdapterFactory` we will use expression conversion instead. -If you do set a custom `SchemaAdapterFactory` we will continue to use it but emit a warning about that code path being deprecated. - -To resolve this you need to implement a custom `PhysicalExprAdapterFactory` and use that instead of a `SchemaAdapterFactory`. -See the [default values](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for an example of how to do this. -Opting into the new APIs will set you up for future changes since we plan to expand use of `PhysicalExprAdapterFactory` to other areas of DataFusion. - -See [#16800] for details. 
- -[#16800]: https://github.com/apache/datafusion/issues/16800 - -### `TableParquetOptions` Updated - -The `TableParquetOptions` struct has a new `crypto` field to specify encryption -options for Parquet files. The `ParquetEncryptionOptions` implements `Default` -so you can upgrade your existing code like this: - -```rust -# /* comment to avoid running -TableParquetOptions { - global, - column_specific_options, - key_value_metadata, -} -# */ -``` - -To this: - -```rust -# /* comment to avoid running -TableParquetOptions { - global, - column_specific_options, - key_value_metadata, - crypto: Default::default(), // New crypto field -} -# */ -``` - -## DataFusion `48.0.1` - -### `datafusion.execution.collect_statistics` now defaults to `true` - -The default value of the `datafusion.execution.collect_statistics` configuration -setting is now true. This change impacts users that use that value directly and relied -on its default value being `false`. - -This change also restores the default behavior of `ListingTable` to its previous behavior. If you use it directly -you can maintain the current behavior by overriding the default value in your code. - -```rust -# /* comment to avoid running -ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(false) - // other options -# */ -``` - -## DataFusion `48.0.0` - -### `Expr::Literal` has optional metadata - -The [`Expr::Literal`] variant now includes optional metadata, which allows for -carrying through Arrow field metadata to support extension types and other uses. - -This means code such as - -```rust -# /* comment to avoid running -match expr { -... - Expr::Literal(scalar) => ... -... -} -# */ -``` - -Should be updated to: - -```rust -# /* comment to avoid running -match expr { -... - Expr::Literal(scalar, _metadata) => ... -... -} -# */ ``` - -Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function -has not changed and returns an `Expr::Literal` with no metadata.
- -[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal -[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html - -### `Expr::WindowFunction` is now `Box`ed - -`Expr::WindowFunction` is now a `Box` instead of a `WindowFunction` directly. -This change was made to reduce the size of `Expr` and improve performance when -planning queries (see [details on #16207]). - -This is a breaking change, so you will need to update your code if you match -on `Expr::WindowFunction` directly. For example, if you have code like this: - -```rust -# /* comment to avoid running -match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - partition_by, - order_by, - .. - } - }) => { - // Use partition_by and order_by as needed - } - _ => { - // other expr - } -} -# */ -``` - -You will need to change it to: - -```rust -# /* comment to avoid running -match expr { - Expr::WindowFunction(window_fun) => { - let WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by, - .. - }, - } = window_fun.as_ref(); - // Use partition_by and order_by as needed - } - _ => { - // other expr - } -} -# */ -``` - -[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 - -### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow - -The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` -which improves performance for many string operations. You can read more about -`Utf8View` in the [DataFusion blog post on German-style strings] - -[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ - -This means that when you create a table with a `VARCHAR` column, it will now use -`Utf8View` as the underlying data type. For example: - -```sql -> CREATE TABLE my_table (my_column VARCHAR); -0 row(s) fetched. -Elapsed 0.001 seconds. 
- -> DESCRIBE my_table; -+-------------+-----------+-------------+ -| column_name | data_type | is_nullable | -+-------------+-----------+-------------+ -| my_column | Utf8View | YES | -+-------------+-----------+-------------+ -1 row(s) fetched. -Elapsed 0.000 seconds. -``` - -You can restore the old behavior of using `Utf8` by changing the -`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For -example - -```sql -> set datafusion.sql_parser.map_varchar_to_utf8view = false; -0 row(s) fetched. -Elapsed 0.001 seconds. - -> CREATE TABLE my_table (my_column VARCHAR); -0 row(s) fetched. -Elapsed 0.014 seconds. - -> DESCRIBE my_table; -+-------------+-----------+-------------+ -| column_name | data_type | is_nullable | -+-------------+-----------+-------------+ -| my_column | Utf8 | YES | -+-------------+-----------+-------------+ -1 row(s) fetched. -Elapsed 0.004 seconds. -``` - -### `ListingOptions` default for `collect_stat` changed from `true` to `false` - -This makes it agree with the default for `SessionConfig`. -Most users won't be impacted by this change but if you were using `ListingOptions` directly -and relied on the default value of `collect_stat` being `true`, you will need to -explicitly set it to `true` in your code. - -```rust -# /* comment to avoid running -ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(true) - // other options -# */ -``` - -### Processing `FieldRef` instead of `DataType` for user defined functions - -In order to support metadata handling and extension types, user defined functions are -now switching to traits which use `FieldRef` rather than a `DataType` and nullability. -This gives a single interface to both of these parameters and additionally allows -access to metadata fields, which can be used for extension types. - -To upgrade structs which implement `ScalarUDFImpl`, if you have implemented -`return_type_from_args` you need instead to implement `return_field_from_args`. 
-If your functions do not need to handle metadata, this should be straightforward -repackaging of the output data into a `FieldRef`. The name you specify on the -field is not important. It will be overwritten during planning. `ReturnInfo` -has been removed, so you will need to remove all references to it. - -`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this -to access the metadata associated with the columnar values during invocation. - -To upgrade user defined aggregate functions, there is now a function -`return_field` that will allow you to specify both metadata and nullability of -your function. You are not required to implement this if you do not need to -handle metadata. - -The largest change to aggregate functions happens in the accumulator arguments. -Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather -than `DataType`. - -To upgrade window functions, `ExpressionArgs` now contains input fields instead -of input data types. When setting these fields, the name of the field is -not important since this gets overwritten during the planning stage. All you -should need to do is wrap your existing data types in fields with nullability -set depending on your use case. - -### Physical Expression return `Field` - -To support the changes to user defined functions processing metadata, the -`PhysicalExpr` trait, which now must specify a return `Field` based on the input -schema. To upgrade structs which implement `PhysicalExpr` you need to implement -the `return_field` function. There are numerous examples in the `physical-expr` -crate. - -### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` - -To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` was replaced with -`FileSource::try_pushdown_filters`. -If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement -`FileSource::try_pushdown_filters`. 
-See `ParquetSource::try_pushdown_filters` for an example of how to implement this. - -`FileFormat::supports_filters_pushdown` has been removed. - -### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed - -`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in -DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal -process described in the [API Deprecation Guidelines] because all the tests -cover the new `DataSourceExec` rather than the older structures. As we evolve -`DataSource`, the old structures began to show signs of "bit rotting" (not -working but no one knows due to lack of test coverage). - -[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines - -### `PartitionedFile` added as an argument to the `FileOpener` trait - -This is necessary to properly fix filter pushdown for filters that combine partition -columns and file columns (e.g. `day = username['dob']`). - -If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument -but are not required to use it in any way. - -## DataFusion `47.0.0` - -This section calls out some of the major changes in the `47.0.0` release of DataFusion. 
- -Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0: - -- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378) -- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563) -- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434) - -### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0 - -Several APIs are changed in the underlying arrow and parquet libraries to use a -`u64` instead of `usize` to better support WASM (See [#7371] and [#6961]) - -Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `static` lifetimes (See [#6619]) - -[#6619]: https://github.com/apache/arrow-rs/pull/6619 -[#7371]: https://github.com/apache/arrow-rs/pull/7371 - -This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as - -```rust -# /* comment to avoid running -impl Objectstore { - ... - // The range is now a u64 instead of usize - async fn get_range(&self, location: &Path, range: Range) -> ObjectStoreResult { - self.inner.get_range(location, range).await - } - ... - // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references) - // (this also applies to list_with_offset) - fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult> { - self.inner.list(prefix) - } -} -# */ -``` - -The `ParquetObjectReader` has been updated to no longer require the object size -(it can be fetched using a single suffix request). 
See [#7334] for details - -[#7334]: https://github.com/apache/arrow-rs/pull/7334 - -Pattern in DataFusion `46.0.0`: - -```rust -# /* comment to avoid running -let meta: ObjectMeta = ...; -let reader = ParquetObjectReader::new(store, meta); -# */ -``` - -Pattern in DataFusion `47.0.0`: - -```rust -# /* comment to avoid running -let meta: ObjectMeta = ...; -let reader = ParquetObjectReader::new(store, location) - .with_file_size(meta.size); -# */ -``` - -### `DisplayFormatType::TreeRender` - -DataFusion now supports [`tree` style explain plans]. Implementations of -`ExecutionPlan` must also provide a description in the -`DisplayFormatType::TreeRender` format. This can be the same as the existing -`DisplayFormatType::Default`. - -[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default - -### Removed Deprecated APIs - -Several APIs have been removed in this release. These were either deprecated -previously or were hard to use correctly such as the multiple different -`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more -details. - -[#15130]: https://github.com/apache/datafusion/pull/15130 -[#15123]: https://github.com/apache/datafusion/pull/15123 -[#15027]: https://github.com/apache/datafusion/pull/15027 - -### `FileScanConfig` --> `FileScanConfigBuilder` - -Previously, `FileScanConfig::build()` directly created ExecutionPlans. In -DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See -[#15352] for details. - -[#15352]: https://github.com/apache/datafusion/pull/15352 - -Pattern in DataFusion `46.0.0`: - -```rust -# /* comment to avoid running -let plan = FileScanConfig::new(url, schema, Arc::new(file_source)) - .with_statistics(stats) - ... - .build() -# */ -``` - -Pattern in DataFusion `47.0.0`: - -```rust -# /* comment to avoid running -let config = FileScanConfigBuilder::new(url, Arc::new(file_source)) - .with_statistics(stats) - ...
- .build(); -let scan = DataSourceExec::from_data_source(config); -# */ -``` - -## DataFusion `46.0.0` - -### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` - -DataFusion is moving to a consistent API for invoking ScalarUDFs, -[`ScalarUDFImpl::invoke_with_args()`], and deprecating -[`ScalarUDFImpl::invoke()`], [`ScalarUDFImpl::invoke_batch()`], and [`ScalarUDFImpl::invoke_no_args()`] - -If you see errors such as the following it means the older APIs are being used: - -```text -This feature is not implemented: Function concat does not implement invoke but called -``` - -To fix this error, use [`ScalarUDFImpl::invoke_with_args()`] instead, as shown -below. See [PR 14876] for an example. - -Given existing code like this: - -```rust -# /* comment to avoid running -impl ScalarUDFImpl for SparkConcat { -... - fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { - if args - .iter() - .any(|arg| matches!(arg.data_type(), DataType::List(_))) - { - ArrayConcat::new().invoke_batch(args, number_rows) - } else { - ConcatFunc::new().invoke_batch(args, number_rows) - } - } -} -# */ -``` - -To - -```rust -# /* comment to avoid running -impl ScalarUDFImpl for SparkConcat { - ... 
- fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - if args - .args - .iter() - .any(|arg| matches!(arg.data_type(), DataType::List(_))) - { - ArrayConcat::new().invoke_with_args(args) - } else { - ConcatFunc::new().invoke_with_args(args) - } - } -} - # */ -``` - -[`scalarudfimpl::invoke()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke -[`scalarudfimpl::invoke_batch()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_batch -[`scalarudfimpl::invoke_no_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_no_args -[`scalarudfimpl::invoke_with_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_with_args -[pr 14876]: https://github.com/apache/datafusion/pull/14876 - -### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` deprecated - -DataFusion 46 has a major change to how the built in DataSources are organized. -Instead of individual `ExecutionPlan`s for the different file formats they now -all use `DataSourceExec` and the format specific information is embodied in new -traits `DataSource` and `FileSource`. - -Here is more information about - -- [Design Ticket] -- Change PR [PR #14224] -- Example of an Upgrade [PR in delta-rs] - -[design ticket]: https://github.com/apache/datafusion/issues/13838 -[pr #14224]: https://github.com/apache/datafusion/pull/14224 -[pr in delta-rs]: https://github.com/delta-io/delta-rs/pull/3261 - -### Cookbook: Changes to `ParquetExecBuilder` - -Code that looks for `ParquetExec` like this will no longer work: - -```rust -# /* comment to avoid running - if let Some(parquet_exec) = plan.as_any().downcast_ref::() { - // Do something with ParquetExec here - } -# */ -``` - -Instead, with `DataSourceExec`, the same information is now on `FileScanConfig` and -`ParquetSource`. 
The equivalent code is - -```rust -# /* comment to avoid running -if let Some(datasource_exec) = plan.as_any().downcast_ref::() { - if let Some(scan_config) = datasource_exec.data_source().as_any().downcast_ref::() { - // FileGroups, and other information is on the FileScanConfig - // parquet - if let Some(parquet_source) = scan_config.file_source.as_any().downcast_ref::() - { - // Information on PruningPredicates and parquet options are here - } -} -# */ -``` - -### Cookbook: Changes to `ParquetExecBuilder` - -Likewise code that builds `ParquetExec` using the `ParquetExecBuilder` such as -the following must be changed: - -```rust -# /* comment to avoid running -let mut exec_plan_builder = ParquetExecBuilder::new( - FileScanConfig::new(self.log_store.object_store_url(), file_schema) - .with_projection(self.projection.cloned()) - .with_limit(self.limit) - .with_table_partition_cols(table_partition_cols), -) -.with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})) -.with_table_parquet_options(parquet_options); - -// Add filter -if let Some(predicate) = logical_filter { - if config.enable_parquet_pushdown { - exec_plan_builder = exec_plan_builder.with_predicate(predicate); - } -}; -# */ -``` - -New code should use `FileScanConfig` to build the appropriate `DataSourceExec`: - -```rust -# /* comment to avoid running -let mut file_source = ParquetSource::new(parquet_options) - .with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})); - -// Add filter -if let Some(predicate) = logical_filter { - if config.enable_parquet_pushdown { - file_source = file_source.with_predicate(predicate); - } -}; - -let file_scan_config = FileScanConfig::new( - self.log_store.object_store_url(), - file_schema, - Arc::new(file_source), -) -.with_statistics(stats) -.with_projection(self.projection.cloned()) -.with_limit(self.limit) -.with_table_partition_cols(table_partition_cols); - -// Build the actual scan like this -parquet_scan: file_scan_config.build(), -# */ -``` 
- -### `datafusion-cli` no longer automatically unescapes strings - -`datafusion-cli` previously would incorrectly unescape string literals (see [ticket] for more details). - -To escape `'` in SQL literals, use `''`: - -```sql -> select 'it''s escaped'; -+----------------------+ -| Utf8("it's escaped") | -+----------------------+ -| it's escaped | -+----------------------+ -1 row(s) fetched. -``` - -To include special characters (such as newlines via `\n`) you can use an `E` literal string. For example - -```sql -> select 'foo\nbar'; -+------------------+ -| Utf8("foo\nbar") | -+------------------+ -| foo\nbar | -+------------------+ -1 row(s) fetched. -Elapsed 0.005 seconds. -``` - -### Changes to array scalar function signatures - -DataFusion 46 has changed the way scalar array function signatures are -declared. Previously, functions needed to select from a list of predefined -signatures within the `ArrayFunctionSignature` enum. Now the signatures -can be defined via a `Vec` of pseudo-types, which each correspond to a -single argument. Those pseudo-types are the variants of the -`ArrayFunctionArgument` enum and are as follows: - -- `Array`: An argument of type List/LargeList/FixedSizeList. All Array - arguments must be coercible to the same type. -- `Element`: An argument that is coercible to the inner type of the `Array` - arguments. -- `Index`: An `Int64` argument. 
- -Each of the old variants can be converted to the new format as follows: - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElement)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], - array_coercion: Some(ListCoercion::FixedSizedListToList), -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ElementAndArray)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Element, ArrayFunctionArgument::Array], - array_coercion: Some(ListCoercion::FixedSizedListToList), -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndIndex)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], - array_coercion: None, -}); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElementAndOptionalIndex)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::OneOf(vec![ - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], - array_coercion: None, - }), - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ - ArrayFunctionArgument::Array, - 
ArrayFunctionArgument::Element, - ArrayFunctionArgument::Index, - ], - array_coercion: None, - }), -]); -``` - -`TypeSignature::ArraySignature(ArrayFunctionSignature::Array)`: - -```rust -# use datafusion::common::utils::ListCoercion; -# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; - -TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, -}); -``` - -Alternatively, you can switch to using one of the following functions which -take care of constructing the `TypeSignature` for you: - -- `Signature::array_and_element` -- `Signature::array_and_element_and_optional_index` -- `Signature::array_and_index` -- `Signature::array` - -[ticket]: https://github.com/apache/datafusion/issues/13286 diff --git a/docs/source/library-user-guide/upgrading/46.0.0.md b/docs/source/library-user-guide/upgrading/46.0.0.md new file mode 100644 index 0000000000000..e38d18c3d6609 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/46.0.0.md @@ -0,0 +1,310 @@ + + +# Upgrade Guides + +## DataFusion 46.0.0 + +### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` + +DataFusion is moving to a consistent API for invoking ScalarUDFs, +[`ScalarUDFImpl::invoke_with_args()`], and deprecating +[`ScalarUDFImpl::invoke()`], [`ScalarUDFImpl::invoke_batch()`], and [`ScalarUDFImpl::invoke_no_args()`] + +If you see errors such as the following it means the older APIs are being used: + +```text +This feature is not implemented: Function concat does not implement invoke but called +``` + +To fix this error, use [`ScalarUDFImpl::invoke_with_args()`] instead, as shown +below. See [PR 14876] for an example. + +Given existing code like this: + +```rust +# /* comment to avoid running +impl ScalarUDFImpl for SparkConcat { +... 
+ fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { + if args + .iter() + .any(|arg| matches!(arg.data_type(), DataType::List(_))) + { + ArrayConcat::new().invoke_batch(args, number_rows) + } else { + ConcatFunc::new().invoke_batch(args, number_rows) + } + } +} +# */ +``` + +To + +```rust +# /* comment to avoid running +impl ScalarUDFImpl for SparkConcat { + ... + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args + .args + .iter() + .any(|arg| matches!(arg.data_type(), DataType::List(_))) + { + ArrayConcat::new().invoke_with_args(args) + } else { + ConcatFunc::new().invoke_with_args(args) + } + } +} + # */ +``` + +[`scalarudfimpl::invoke()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke +[`scalarudfimpl::invoke_batch()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_batch +[`scalarudfimpl::invoke_no_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_no_args +[`scalarudfimpl::invoke_with_args()`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.invoke_with_args +[pr 14876]: https://github.com/apache/datafusion/pull/14876 + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` deprecated + +DataFusion 46 has a major change to how the built in DataSources are organized. +Instead of individual `ExecutionPlan`s for the different file formats they now +all use `DataSourceExec` and the format specific information is embodied in new +traits `DataSource` and `FileSource`. 
+ +Here is more information about + +- [Design Ticket] +- Change PR [PR #14224] +- Example of an Upgrade [PR in delta-rs] + +[design ticket]: https://github.com/apache/datafusion/issues/13838 +[pr #14224]: https://github.com/apache/datafusion/pull/14224 +[pr in delta-rs]: https://github.com/delta-io/delta-rs/pull/3261 + +### Cookbook: Changes to `ParquetExecBuilder` + +Code that looks for `ParquetExec` like this will no longer work: + +```rust +# /* comment to avoid running + if let Some(parquet_exec) = plan.as_any().downcast_ref::() { + // Do something with ParquetExec here + } +# */ +``` + +Instead, with `DataSourceExec`, the same information is now on `FileScanConfig` and +`ParquetSource`. The equivalent code is + +```rust +# /* comment to avoid running +if let Some(datasource_exec) = plan.as_any().downcast_ref::() { + if let Some(scan_config) = datasource_exec.data_source().as_any().downcast_ref::() { + // FileGroups, and other information is on the FileScanConfig + // parquet + if let Some(parquet_source) = scan_config.file_source.as_any().downcast_ref::() + { + // Information on PruningPredicates and parquet options are here + } +} +# */ +``` + +### Cookbook: Changes to `ParquetExecBuilder` + +Likewise code that builds `ParquetExec` using the `ParquetExecBuilder` such as +the following must be changed: + +```rust +# /* comment to avoid running +let mut exec_plan_builder = ParquetExecBuilder::new( + FileScanConfig::new(self.log_store.object_store_url(), file_schema) + .with_projection(self.projection.cloned()) + .with_limit(self.limit) + .with_table_partition_cols(table_partition_cols), +) +.with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})) +.with_table_parquet_options(parquet_options); + +// Add filter +if let Some(predicate) = logical_filter { + if config.enable_parquet_pushdown { + exec_plan_builder = exec_plan_builder.with_predicate(predicate); + } +}; +# */ +``` + +New code should use `FileScanConfig` to build the appropriate 
`DataSourceExec`: + +```rust +# /* comment to avoid running +let mut file_source = ParquetSource::new(parquet_options) + .with_schema_adapter_factory(Arc::new(DeltaSchemaAdapterFactory {})); + +// Add filter +if let Some(predicate) = logical_filter { + if config.enable_parquet_pushdown { + file_source = file_source.with_predicate(predicate); + } +}; + +let file_scan_config = FileScanConfig::new( + self.log_store.object_store_url(), + file_schema, + Arc::new(file_source), +) +.with_statistics(stats) +.with_projection(self.projection.cloned()) +.with_limit(self.limit) +.with_table_partition_cols(table_partition_cols); + +// Build the actual scan like this +parquet_scan: file_scan_config.build(), +# */ +``` + +### `datafusion-cli` no longer automatically unescapes strings + +`datafusion-cli` previously would incorrectly unescape string literals (see [ticket] for more details). + +To escape `'` in SQL literals, use `''`: + +```sql +> select 'it''s escaped'; ++----------------------+ +| Utf8("it's escaped") | ++----------------------+ +| it's escaped | ++----------------------+ +1 row(s) fetched. +``` + +To include special characters (such as newlines via `\n`) you can use an `E` literal string. For example + +```sql +> select 'foo\nbar'; ++------------------+ +| Utf8("foo\nbar") | ++------------------+ +| foo\nbar | ++------------------+ +1 row(s) fetched. +Elapsed 0.005 seconds. +``` + +### Changes to array scalar function signatures + +DataFusion 46 has changed the way scalar array function signatures are +declared. Previously, functions needed to select from a list of predefined +signatures within the `ArrayFunctionSignature` enum. Now the signatures +can be defined via a `Vec` of pseudo-types, which each correspond to a +single argument. Those pseudo-types are the variants of the +`ArrayFunctionArgument` enum and are as follows: + +- `Array`: An argument of type List/LargeList/FixedSizeList. All Array + arguments must be coercible to the same type. 
+- `Element`: An argument that is coercible to the inner type of the `Array` + arguments. +- `Index`: An `Int64` argument. + +Each of the old variants can be converted to the new format as follows: + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElement)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], + array_coercion: Some(ListCoercion::FixedSizedListToList), +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ElementAndArray)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Element, ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndIndex)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], + array_coercion: None, +}); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::ArrayAndElementAndOptionalIndex)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::OneOf(vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Element], + array_coercion: None, + }), + 
TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), +]); +``` + +`TypeSignature::ArraySignature(ArrayFunctionSignature::Array)`: + +```rust +# use datafusion::common::utils::ListCoercion; +# use datafusion_expr_common::signature::{ArrayFunctionArgument, ArrayFunctionSignature, TypeSignature}; + +TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, +}); +``` + +Alternatively, you can switch to using one of the following functions which +take care of constructing the `TypeSignature` for you: + +- `Signature::array_and_element` +- `Signature::array_and_element_and_optional_index` +- `Signature::array_and_index` +- `Signature::array` + +[ticket]: https://github.com/apache/datafusion/issues/13286 diff --git a/docs/source/library-user-guide/upgrading/47.0.0.md b/docs/source/library-user-guide/upgrading/47.0.0.md new file mode 100644 index 0000000000000..354b6740df02f --- /dev/null +++ b/docs/source/library-user-guide/upgrading/47.0.0.md @@ -0,0 +1,135 @@ + + +# Upgrade Guides + +## DataFusion 47.0.0 + +This section calls out some of the major changes in the `47.0.0` release of DataFusion. 
+
+Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0:
+
+- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378)
+- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563)
+- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434)
+
+### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0
+
+Several APIs are changed in the underlying arrow and parquet libraries to use a
+`u64` instead of `usize` to better support WASM (See [#7371] and [#6961])
+
+Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `'static` lifetimes (See [#6619])
+
+[#6619]: https://github.com/apache/arrow-rs/pull/6619
+[#6961]: https://github.com/apache/arrow-rs/issues/6961
+[#7371]: https://github.com/apache/arrow-rs/pull/7371
+
+This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as
+
+```rust
+# /* comment to avoid running
+impl Objectstore {
+    ...
+    // The range is now a u64 instead of usize
+    async fn get_range(&self, location: &Path, range: Range<u64>) -> ObjectStoreResult<Bytes> {
+        self.inner.get_range(location, range).await
+    }
+    ...
+    // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references)
+    // (this also applies to list_with_offset)
+    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult<ObjectMeta>> {
+        self.inner.list(prefix)
+    }
+}
+# */
+```
+
+The `ParquetObjectReader` has been updated to no longer require the object size
+(it can be fetched using a single suffix request).
See [#7334] for details
+
+[#7334]: https://github.com/apache/arrow-rs/pull/7334
+
+Pattern in DataFusion `46.0.0`:
+
+```rust
+# /* comment to avoid running
+let meta: ObjectMeta = ...;
+let reader = ParquetObjectReader::new(store, meta);
+# */
+```
+
+Pattern in DataFusion `47.0.0`:
+
+```rust
+# /* comment to avoid running
+let meta: ObjectMeta = ...;
+let reader = ParquetObjectReader::new(store, location)
+    .with_file_size(meta.size);
+# */
+```
+
+### `DisplayFormatType::TreeRender`
+
+DataFusion now supports [`tree` style explain plans]. Implementations of
+`ExecutionPlan` must also provide a description in the
+`DisplayFormatType::TreeRender` format. This can be the same as the existing
+`DisplayFormatType::Default`.
+
+[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default
+
+### Removed Deprecated APIs
+
+Several APIs have been removed in this release. These were either deprecated
+previously or were hard to use correctly such as the multiple different
+`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more
+details.
+
+[#15130]: https://github.com/apache/datafusion/pull/15130
+[#15123]: https://github.com/apache/datafusion/pull/15123
+[#15027]: https://github.com/apache/datafusion/pull/15027
+
+### `FileScanConfig` --> `FileScanConfigBuilder`
+
+Previously, `FileScanConfig::build()` directly created ExecutionPlans. In
+DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See
+[#15352] for details.
+
+[#15352]: https://github.com/apache/datafusion/pull/15352
+
+Pattern in DataFusion `46.0.0`:
+
+```rust
+# /* comment to avoid running
+let plan = FileScanConfig::new(url, schema, Arc::new(file_source))
+    .with_statistics(stats)
+    ...
+    .build()
+# */
+```
+
+Pattern in DataFusion `47.0.0`:
+
+```rust
+# /* comment to avoid running
+let config = FileScanConfigBuilder::new(url, Arc::new(file_source))
+    .with_statistics(stats)
+    ...
+ .build(); +let scan = DataSourceExec::from_data_source(config); +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/48.0.0.md b/docs/source/library-user-guide/upgrading/48.0.0.md new file mode 100644 index 0000000000000..7872a6f54f245 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/48.0.0.md @@ -0,0 +1,244 @@ + + +# Upgrade Guides + +## DataFusion 48.0.0 + +### `Expr::Literal` has optional metadata + +The [`Expr::Literal`] variant now includes optional metadata, which allows for +carrying through Arrow field metadata to support extension types and other uses. + +This means code such as + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar) => ... +... +} +# */ +``` + +Should be updated to: + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar, _metadata) => ... +... +} +# */ +``` + +Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function +has not changed and returns an `Expr::Literal` with no metadata. + +[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal +[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html + +### `Expr::WindowFunction` is now `Box`ed + +`Expr::WindowFunction` is now a `Box` instead of a `WindowFunction` directly. +This change was made to reduce the size of `Expr` and improve performance when +planning queries (see [details on #16207]). + +This is a breaking change, so you will need to update your code if you match +on `Expr::WindowFunction` directly. For example, if you have code like this: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + partition_by, + order_by, + .. 
+ } + }) => { + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +You will need to change it to: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by, + .. + }, + } = window_fun.as_ref(); + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 + +### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow + +The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` +which improves performance for many string operations. You can read more about +`Utf8View` in the [DataFusion blog post on German-style strings] + +[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ + +This means that when you create a table with a `VARCHAR` column, it will now use +`Utf8View` as the underlying data type. For example: + +```sql +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.001 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8View | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.000 seconds. +``` + +You can restore the old behavior of using `Utf8` by changing the +`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For +example + +```sql +> set datafusion.sql_parser.map_varchar_to_utf8view = false; +0 row(s) fetched. +Elapsed 0.001 seconds. + +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.014 seconds. 
+ +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8 | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.004 seconds. +``` + +### `ListingOptions` default for `collect_stat` changed from `true` to `false` + +This makes it agree with the default for `SessionConfig`. +Most users won't be impacted by this change but if you were using `ListingOptions` directly +and relied on the default value of `collect_stat` being `true`, you will need to +explicitly set it to `true` in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true) + // other options +# */ +``` + +### Processing `FieldRef` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `FieldRef` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `FieldRef`. The name you specify on the +field is not important. It will be overwritten during planning. `ReturnInfo` +has been removed, so you will need to remove all references to it. + +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. + +To upgrade user defined aggregate functions, there is now a function +`return_field` that will allow you to specify both metadata and nullability of +your function. 
You are not required to implement this if you do not need to +handle metadata. + +The largest change to aggregate functions happens in the accumulator arguments. +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather +than `DataType`. + +To upgrade window functions, `ExpressionArgs` now contains input fields instead +of input data types. When setting these fields, the name of the field is +not important since this gets overwritten during the planning stage. All you +should need to do is wrap your existing data types in fields with nullability +set depending on your use case. + +### Physical Expression return `Field` + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait, which now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + +### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` + +To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` was replaced with +`FileSource::try_pushdown_filters`. +If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement +`FileSource::try_pushdown_filters`. +See `ParquetSource::try_pushdown_filters` for an example of how to implement this. + +`FileFormat::supports_filters_pushdown` has been removed. + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed + +`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in +DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal +process described in the [API Deprecation Guidelines] because all the tests +cover the new `DataSourceExec` rather than the older structures. 
As we evolve +`DataSource`, the old structures began to show signs of "bit rotting" (not +working but no one knows due to lack of test coverage). + +[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines + +### `PartitionedFile` added as an argument to the `FileOpener` trait + +This is necessary to properly fix filter pushdown for filters that combine partition +columns and file columns (e.g. `day = username['dob']`). + +If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument +but are not required to use it in any way. diff --git a/docs/source/library-user-guide/upgrading/48.0.1.md b/docs/source/library-user-guide/upgrading/48.0.1.md new file mode 100644 index 0000000000000..5dfb9e1e3d0b1 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/48.0.1.md @@ -0,0 +1,39 @@ + + +# Upgrade Guides + +## DataFusion 48.0.1 + +### `datafusion.execution.collect_statistics` now defaults to `true` + +The default value of the `datafusion.execution.collect_statistics` configuration +setting is now true. This change impacts users that use that value directly and relied +on its default value being `false`. + +This change also restores the default behavior of `ListingTable` to its previous. If you use it directly +you can maintain the current behavior by overriding the default value in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false) + // other options +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/49.0.0.md b/docs/source/library-user-guide/upgrading/49.0.0.md new file mode 100644 index 0000000000000..92dee8135590a --- /dev/null +++ b/docs/source/library-user-guide/upgrading/49.0.0.md @@ -0,0 +1,222 @@ + + +# Upgrade Guides + +## DataFusion 49.0.0 + +### `MSRV` updated to 1.85.1 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.85.1`]. 
See +[#16728] for details. + +[`1.85.1`]: https://releases.rs/docs/1.85.1/ +[#16728]: https://github.com/apache/datafusion/pull/16728 + +### `DataFusionError` variants are now `Box`ed + +To reduce the size of `DataFusionError`, several variants that were previously stored inline are now `Box`ed. This reduces the size of `Result` and thus stack usage and async state machine size. Please see [#16652] for more details. + +The following variants of `DataFusionError` are now boxed: + +- `ArrowError` +- `SQL` +- `SchemaError` + +This is a breaking change. Code that constructs or matches on these variants will need to be updated. + +For example, to create a `SchemaError`, instead of: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }, + Box::new(None) +) +# */ +``` + +You now need to `Box` the inner error: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + Box::new(SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }), + Box::new(None) +) +# */ +``` + +[#16652]: https://github.com/apache/datafusion/issues/16652 + +### Metadata on Arrow Types is now represented by `FieldMetadata` + +Metadata from the Arrow `Field` is now stored using the `FieldMetadata` +structure. In prior versions it was stored as both a `HashMap` +and a `BTreeMap`. `FieldMetadata` is a easier to work with and +is more efficient. + +To create `FieldMetadata` from a `Field`: + +```rust +# /* comment to avoid running + let metadata = FieldMetadata::from(&field); +# */ +``` + +To add metadata to a `Field`, use the `add_to_field` method: + +```rust +# /* comment to avoid running +let updated_field = metadata.add_to_field(field); +# */ +``` + +See [#16317] for details. 
+ +[#16317]: https://github.com/apache/datafusion/pull/16317 + +### New `datafusion.execution.spill_compression` configuration option + +DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. + +**Configuration:** + +- **Key**: `datafusion.execution.spill_compression` +- **Default**: `uncompressed` +- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` + +**Usage:** + +```rust +# /* comment to avoid running +use datafusion::prelude::*; +use datafusion_common::config::SpillCompression; + +let config = SessionConfig::default() + .with_spill_compression(SpillCompression::Zstd); +let ctx = SessionContext::new_with_config(config); +# */ +``` + +Or via SQL: + +```sql +SET datafusion.execution.spill_compression = 'zstd'; +``` + +For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../../user-guide/configs) documentation. + +### Deprecated `map_varchar_to_utf8view` configuration option + +See [issue #16290](https://github.com/apache/datafusion/pull/16290) for more information +The old configuration + +```text +datafusion.sql_parser.map_varchar_to_utf8view +``` + +is now **deprecated** in favor of the unified option below.\ +If you previously used this to control only `VARCHAR`→`Utf8View` mapping, please migrate to `map_string_types_to_utf8view`. + +--- + +### New `map_string_types_to_utf8view` configuration option + +To unify **all** SQL string types (`CHAR`, `VARCHAR`, `TEXT`, `STRING`) to Arrow’s zero‑copy `Utf8View`, DataFusion 49.0.0 introduces: + +- **Key**: `datafusion.sql_parser.map_string_types_to_utf8view` +- **Default**: `true` + +**Description:** + +- When **true** (default), **all** SQL string types are mapped to `Utf8View`, avoiding full‑copy UTF‑8 allocations and improving performance. 
+- When **false**, DataFusion falls back to the legacy `Utf8` mapping for **all** string types. + +#### Examples + +```rust +# /* comment to avoid running +// Disable Utf8View mapping for all SQL string types +let opts = datafusion::sql::planner::ParserOptions::new() + .with_map_string_types_to_utf8view(false); + +// Verify the setting is applied +assert!(!opts.map_string_types_to_utf8view); +# */ +``` + +--- + +```sql +-- Disable Utf8View mapping globally +SET datafusion.sql_parser.map_string_types_to_utf8view = false; + +-- Now VARCHAR, CHAR, TEXT, STRING all use Utf8 rather than Utf8View +CREATE TABLE my_table (a VARCHAR, b TEXT, c STRING); +DESCRIBE my_table; +``` + +### Deprecating `SchemaAdapterFactory` and `SchemaAdapter` + +We are moving away from converting data (using `SchemaAdapter`) to converting the expressions themselves (which is more efficient and flexible). + +See [issue #16800](https://github.com/apache/datafusion/issues/16800) for more information +The first place this change has taken place is in predicate pushdown for Parquet. +By default if you do not use a custom `SchemaAdapterFactory` we will use expression conversion instead. +If you do set a custom `SchemaAdapterFactory` we will continue to use it but emit a warning about that code path being deprecated. + +To resolve this you need to implement a custom `PhysicalExprAdapterFactory` and use that instead of a `SchemaAdapterFactory`. +See the [default values](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for an example of how to do this. +Opting into the new APIs will set you up for future changes since we plan to expand use of `PhysicalExprAdapterFactory` to other areas of DataFusion. + +See [#16800] for details. 
+ +[#16800]: https://github.com/apache/datafusion/issues/16800 + +### `TableParquetOptions` Updated + +The `TableParquetOptions` struct has a new `crypto` field to specify encryption +options for Parquet files. The `ParquetEncryptionOptions` implements `Default` +so you can upgrade your existing code like this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, + crypto: Default::default(), // New crypto field +} +# */ +``` diff --git a/docs/source/library-user-guide/upgrading/50.0.0.md b/docs/source/library-user-guide/upgrading/50.0.0.md new file mode 100644 index 0000000000000..d8155dab58962 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/50.0.0.md @@ -0,0 +1,330 @@ + + +# Upgrade Guides + +## DataFusion 50.0.0 + +### ListingTable automatically detects Hive Partitioned tables + +DataFusion 50.0.0 automatically infers Hive partitions when using the `ListingTableFactory` and `CREATE EXTERNAL TABLE`. Previously, +when creating a `ListingTable`, datasets that use Hive partitioning (e.g. +`/table_root/column1=value1/column2=value2/data.parquet`) would not have the Hive columns reflected in +the table's schema or data. The previous behavior can be +restored by setting the `datafusion.execution.listing_table_factory_infer_partitions` configuration option to `false`. +See [issue #17049] for more details. + +[issue #17049]: https://github.com/apache/datafusion/issues/17049 + +### `MSRV` updated to 1.86.0 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.86.0`]. +See [#17230] for details. 
+
+[`1.86.0`]: https://releases.rs/docs/1.86.0/
+[#17230]: https://github.com/apache/datafusion/pull/17230
+
+### `ScalarUDFImpl`, `AggregateUDFImpl` and `WindowUDFImpl` traits now require `PartialEq`, `Eq`, and `Hash` traits
+
+To address the error-proneness of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and
+`WindowUDFImpl::equals` methods and to make it easy to implement function equality correctly,
+the `equals` and `hash_value` methods have been removed from `ScalarUDFImpl`, `AggregateUDFImpl`
+and `WindowUDFImpl` traits. They are replaced by the requirement to implement the `PartialEq`, `Eq`,
+and `Hash` traits on any type implementing `ScalarUDFImpl`, `AggregateUDFImpl` or `WindowUDFImpl`.
+Please see [issue #16677] for more details.
+
+Most of the scalar functions are stateless and have a `signature` field. These can be migrated
+using regular expressions:
+
+- search for `\#\[derive\(Debug\)\](\n *(pub )?struct \w+ \{\n *signature\: Signature\,\n *\})`,
+- replace with `#[derive(Debug, PartialEq, Eq, Hash)]$1`,
+- review all the changes and make sure only function structs were changed.
+
+[issue #16677]: https://github.com/apache/datafusion/issues/16677
+
+### `AsyncScalarUDFImpl::invoke_async_with_args` returns `ColumnarValue`
+
+In order to enable single value optimizations and be consistent with other
+user defined function APIs, the `AsyncScalarUDFImpl::invoke_async_with_args` method now
+returns a `ColumnarValue` instead of an `ArrayRef`.
+
+To upgrade, change the return type of your implementation:
+
+```rust
+# /* comment to avoid running
+impl AsyncScalarUDFImpl for AskLLM {
+    async fn invoke_async_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+        _option: &ConfigOptions,
+    ) -> Result {
+        ..
+        return array_ref; // old code
+    }
+}
+# */
+```
+
+To return a `ColumnarValue`:
+
+```rust
+# /* comment to avoid running
+impl AsyncScalarUDFImpl for AskLLM {
+    async fn invoke_async_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+        _option: &ConfigOptions,
+    ) -> Result {
+        ..
+        return ColumnarValue::from(array_ref); // new code
+    }
+}
+# */
+```
+
+See [#16896](https://github.com/apache/datafusion/issues/16896) for more details.
+
+### `ProjectionExpr` changed from type alias to struct
+
+`ProjectionExpr` has been changed from a type alias to a struct with named fields to improve code clarity and maintainability.
+
+**Before:**
+
+```rust,ignore
+pub type ProjectionExpr = (Arc, String);
+```
+
+**After:**
+
+```rust,ignore
+#[derive(Debug, Clone)]
+pub struct ProjectionExpr {
+    pub expr: Arc,
+    pub alias: String,
+}
+```
+
+To upgrade your code:
+
+- Replace tuple construction `(expr, alias)` with `ProjectionExpr::new(expr, alias)` or `ProjectionExpr { expr, alias }`
+- Replace tuple field access `.0` and `.1` with `.expr` and `.alias`
+- Update pattern matching from `(expr, alias)` to `ProjectionExpr { expr, alias }`
+
+This mainly impacts use of `ProjectionExec`.
+
+This change was done in [#17398].
+
+[#17398]: https://github.com/apache/datafusion/pull/17398
+
+### `SessionState`, `SessionConfig`, and `OptimizerConfig` return `&Arc` instead of `&ConfigOptions`
+
+To provide broader access to `ConfigOptions` and reduce required clones, some
+APIs have been changed to return a `&Arc` instead of a
+`&ConfigOptions`. This allows sharing the same `ConfigOptions` across multiple
+threads without needing to clone the entire `ConfigOptions` structure unless it
+is modified.
+
+Most users will not be impacted by this change since the Rust compiler typically
+automatically dereferences the `Arc` when needed.
However, in some cases you may +have to change your code to explicitly call `as_ref()` for example, from + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options(); +# */ +``` + +To + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options().as_ref(); +# */ +``` + +See PR [#16970](https://github.com/apache/datafusion/pull/16970) + +### API Change to `AsyncScalarUDFImpl::invoke_async_with_args` + +The `invoke_async_with_args` method of the `AsyncScalarUDFImpl` trait has been +updated to remove the `_option: &ConfigOptions` parameter to simplify the API +now that the `ConfigOptions` can be accessed through the `ScalarFunctionArgs` +parameter. + +You can change your code like this + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + } + ... +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running + +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let options = &args.config_options; + .. + } + ... +} +# */ +``` + +### Schema Rewriter Module Moved to New Crate + +The `schema_rewriter` module and its associated symbols have been moved from `datafusion_physical_expr` to a new crate `datafusion_physical_expr_adapter`. This affects the following symbols: + +- `DefaultPhysicalExprAdapter` +- `DefaultPhysicalExprAdapterFactory` +- `PhysicalExprAdapter` +- `PhysicalExprAdapterFactory` + +To upgrade, change your imports to: + +```rust +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, + PhysicalExprAdapter, PhysicalExprAdapterFactory +}; +``` + +### Upgrade to arrow `56.0.0` and parquet `56.0.0` + +This version of DataFusion upgrades the underlying Apache Arrow implementation +to version `56.0.0`. 
See the [release notes](https://github.com/apache/arrow-rs/releases/tag/56.0.0) +for more details. + +### Added `ExecutionPlan::reset_state` + +In order to fix a bug in DataFusion `49.0.0` where dynamic filters (currently only generated in the presence of a query such as `ORDER BY ... LIMIT ...`) +produced incorrect results in recursive queries, a new method `reset_state` has been added to the `ExecutionPlan` trait. + +Any `ExecutionPlan` that needs to maintain internal state or references to other nodes in the execution plan tree should implement this method to reset that state. +See [#17028] for more details and an example implementation for `SortExec`. + +[#17028]: https://github.com/apache/datafusion/pull/17028 + +### Nested Loop Join input sort order cannot be preserved + +The Nested Loop Join operator has been rewritten from scratch to improve performance and memory efficiency. From the micro-benchmarks: this change introduces up to 5X speed-up and uses only 1% memory in extreme cases compared to the previous implementation. + +However, the new implementation cannot preserve input sort order like the old version could. This is a fundamental design trade-off that prioritizes performance and memory efficiency over sort order preservation. + +See [#16996] for details. + +[#16996]: https://github.com/apache/datafusion/pull/16996 + +### Add `as_any()` method to `LazyBatchGenerator` + +To help with protobuf serialization, the `as_any()` method has been added to the `LazyBatchGenerator` trait. This means you will need to add `as_any()` to your implementation of `LazyBatchGenerator`: + +```rust +# /* comment to avoid running + +impl LazyBatchGenerator for MyBatchGenerator { + fn as_any(&self) -> &dyn Any { + self + } + + ... +} + +# */ +``` + +See [#17200](https://github.com/apache/datafusion/pull/17200) for details. 
+ +### Refactored `DataSource::try_swapping_with_projection` + +We refactored `DataSource::try_swapping_with_projection` to simplify the method and minimize leakage across the ExecutionPlan <-> DataSource abstraction layer. +Reimplementation for any custom `DataSource` should be relatively straightforward, see [#17395] for more details. + +[#17395]: https://github.com/apache/datafusion/pull/17395/ + +### `FileOpenFuture` now uses `DataFusionError` instead of `ArrowError` + +The `FileOpenFuture` type alias has been updated to use `DataFusionError` instead of `ArrowError` for its error type. This change affects the `FileOpener` trait and any implementations that work with file streaming operations. + +**Before:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +**After:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +If you have custom implementations of `FileOpener` or work directly with `FileOpenFuture`, you'll need to update your error handling to use `DataFusionError` instead of `ArrowError`. The `FileStreamState` enum's `Open` variant has also been updated accordingly. See [#17397] for more details. + +[#17397]: https://github.com/apache/datafusion/pull/17397 + +### FFI user defined aggregate function signature change + +The Foreign Function Interface (FFI) signature for user defined aggregate functions +has been updated to call `return_field` instead of `return_type` on the underlying +aggregate function. This is to support metadata handling with these aggregate functions. +This change should be transparent to most users. If you have written unit tests to call +`return_type` directly, you may need to change them to calling `return_field` instead. + +This update is a breaking change to the FFI API. The current best practice when using the +FFI crate is to ensure that all libraries that are interacting are using the same +underlying Rust version. 
Issue [#17374] has been opened to discuss stabilization of
+this interface so that these libraries can be used across different DataFusion versions.
+
+See [#17407] for details.
+
+[#17407]: https://github.com/apache/datafusion/pull/17407
+[#17374]: https://github.com/apache/datafusion/issues/17374
+
+### Added `PhysicalExpr::is_volatile_node`
+
+We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile:
+
+```rust,ignore
+impl PhysicalExpr for MyRandomExpr {
+    fn is_volatile_node(&self) -> bool {
+        true
+    }
+}
+```
+
+We've shipped this with a default value of `false` to minimize breakage, but we highly recommend that implementers of `PhysicalExpr` opt into a behavior, even if it is returning `false`.
+
+You can see more discussion and example implementations in [#17351].
+
+[#17351]: https://github.com/apache/datafusion/pull/17351
diff --git a/docs/source/library-user-guide/upgrading/51.0.0.md b/docs/source/library-user-guide/upgrading/51.0.0.md
new file mode 100644
index 0000000000000..c3acfe15c493f
--- /dev/null
+++ b/docs/source/library-user-guide/upgrading/51.0.0.md
@@ -0,0 +1,272 @@
+
+
+# Upgrade Guides
+
+## DataFusion 51.0.0
+
+### `arrow` / `parquet` updated to 57.0.0
+
+### Upgrade to arrow `57.0.0` and parquet `57.0.0`
+
+This version of DataFusion upgrades the underlying Apache Arrow implementation
+to version `57.0.0`, including several dependent crates such as `prost`,
+`tonic`, `pyo3`, and `substrait`. See the [release
+notes](https://github.com/apache/arrow-rs/releases/tag/57.0.0) for more details.
+
+### `MSRV` updated to 1.88.0
+
+The Minimum Supported Rust Version (MSRV) has been updated to [`1.88.0`].
+
+[`1.88.0`]: https://releases.rs/docs/1.88.0/
+
+### `FunctionRegistry` exposes two additional methods
+
+`FunctionRegistry` exposes two additional methods, `udafs` and `udwfs`, which expose the set of registered user defined aggregation and window function names.
To upgrade, implement methods returning the set of registered function names:
+
+```diff
+impl FunctionRegistry for FunctionRegistryImpl {
+    fn udfs(&self) -> HashSet {
+        self.scalar_functions.keys().cloned().collect()
+    }
++    fn udafs(&self) -> HashSet {
++        self.aggregate_functions.keys().cloned().collect()
++    }
++
++    fn udwfs(&self) -> HashSet {
++        self.window_functions.keys().cloned().collect()
++    }
+}
+```
+
+### `datafusion-proto` uses `TaskContext` rather than `SessionContext` in physical plan serde methods
+
+There have been changes in the public API methods of `datafusion-proto` which handle physical plan serde.
+
+Methods like `physical_plan_from_bytes`, `parse_physical_expr` and similar now expect `TaskContext` instead of `SessionContext`:
+
+```diff
+- let plan2 = physical_plan_from_bytes(&bytes, &ctx)?;
++ let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?;
+```
+
+Since `TaskContext` contains `RuntimeEnv`, methods such as `try_into_physical_plan` no longer have an explicit `RuntimeEnv` parameter:
+
+```diff
+let result_exec_plan: Arc = proto
+-    .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec)
++    .try_into_physical_plan(&ctx.task_ctx(), &composed_codec)
+```
+
+`PhysicalExtensionCodec::try_decode()` expects `TaskContext` instead of `FunctionRegistry`:
+
+```diff
+pub trait PhysicalExtensionCodec {
+    fn try_decode(
+        &self,
+        buf: &[u8],
+        inputs: &[Arc],
+-        registry: &dyn FunctionRegistry,
++        ctx: &TaskContext,
+    ) -> Result>;
+```
+
+See [issue #17601] for more details.
+
+[issue #17601]: https://github.com/apache/datafusion/issues/17601
+
+### `SessionState`'s `sql_to_statement` method takes `Dialect` rather than a `str`
+
+The `dialect` parameter of the `sql_to_statement` method defined in `datafusion::execution::session_state::SessionState`
+has changed from `&str` to `&Dialect`.
+`Dialect` is an enum defined in the `datafusion-common`
+crate under the `config` module that provides type safety
+and better validation for SQL dialect selection.
+
+### Reorganization of `ListingTable` into `datafusion-catalog-listing` crate
+
+There has been a long-standing request to remove features such as `ListingTable`
+from the `datafusion` crate to support faster build times. The structs
+`ListingOptions`, `ListingTable`, and `ListingTableConfig` are now available
+within the `datafusion-catalog-listing` crate. These are re-exported in
+the `datafusion` crate, so this should have minimal impact on existing users.
+
+See [issue #14462] and [issue #17713] for more details.
+
+[issue #14462]: https://github.com/apache/datafusion/issues/14462
+[issue #17713]: https://github.com/apache/datafusion/issues/17713
+
+### Reorganization of `ArrowSource` into `datafusion-datasource-arrow` crate
+
+To support [issue #17713], the `ArrowSource` code has been moved from
+the `datafusion` core crate into its own crate, `datafusion-datasource-arrow`.
+This follows the pattern for the AVRO, CSV, JSON, and Parquet data sources.
+Users may need to update their paths to account for these changes.
+
+See [issue #17713] for more details.
+
+### `FileScanConfig::projection` renamed to `FileScanConfig::projection_exprs`
+
+The `projection` field in `FileScanConfig` has been renamed to `projection_exprs` and its type has changed from `Option>` to `Option`. This change enables more powerful projection pushdown capabilities by supporting arbitrary physical expressions rather than just column indices.
+ +**Impact on direct field access:** + +If you directly access the `projection` field: + +```rust,ignore +let config: FileScanConfig = ...; +let projection = config.projection; +``` + +You should update to: + +```rust,ignore +let config: FileScanConfig = ...; +let projection_exprs = config.projection_exprs; +``` + +**Impact on builders:** + +The `FileScanConfigBuilder::with_projection()` method has been deprecated in favor of `with_projection_indices()`: + +```diff +let config = FileScanConfigBuilder::new(url, file_source) +- .with_projection(Some(vec![0, 2, 3])) ++ .with_projection_indices(Some(vec![0, 2, 3])) + .build(); +``` + +Note: `with_projection()` still works but is deprecated and will be removed in a future release. + +**What is `ProjectionExprs`?** + +`ProjectionExprs` is a new type that represents a list of physical expressions for projection. While it can be constructed from column indices (which is what `with_projection_indices` does internally), it also supports arbitrary physical expressions, enabling advanced features like expression evaluation during scanning. + +You can access column indices from `ProjectionExprs` using its methods if needed: + +```rust,ignore +let projection_exprs: ProjectionExprs = ...; +// Get the column indices if the projection only contains simple column references +let indices = projection_exprs.column_indices(); +``` + +### `DESCRIBE query` support + +`DESCRIBE query` was previously an alias for `EXPLAIN query`, which outputs the +_execution plan_ of the query. With this release, `DESCRIBE query` now outputs +the computed _schema_ of the query, consistent with the behavior of `DESCRIBE table_name`. + +### `datafusion.execution.time_zone` default configuration changed + +The default value for `datafusion.execution.time_zone` previously was a string value of `+00:00` (GMT/Zulu time). +This was changed to be an `Option` with a default of `None`. 
If you want to change the timezone back +to the previous value you can execute the sql: + +```sql +SET +TIMEZONE = '+00:00'; +``` + +This change was made to better support using the default timezone in scalar UDF functions such as +`now`, `current_date`, `current_time`, and `to_timestamp` among others. + +### Introduction of `TableSchema` and changes to `FileSource::with_schema()` method + +A new `TableSchema` struct has been introduced in the `datafusion-datasource` crate to better manage table schemas with partition columns. This struct helps distinguish between: + +- **File schema**: The schema of actual data files on disk +- **Partition columns**: Columns derived from directory structure (e.g., Hive-style partitioning) +- **Table schema**: The complete schema combining both file and partition columns + +As part of this change, the `FileSource::with_schema()` method signature has changed from accepting a `SchemaRef` to accepting a `TableSchema`. + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations will need to update their code +- Users who only use built-in file sources (Parquet, CSV, JSON, AVRO, Arrow) are not affected + +**Migration guide for custom `FileSource` implementations:** + +```diff + use datafusion_datasource::file::FileSource; +-use arrow::datatypes::SchemaRef; ++use datafusion_datasource::TableSchema; + + impl FileSource for MyCustomSource { +- fn with_schema(&self, schema: SchemaRef) -> Arc { ++ fn with_schema(&self, schema: TableSchema) -> Arc { + Arc::new(Self { +- schema: Some(schema), ++ // Use schema.file_schema() to get the file schema without partition columns ++ schema: Some(Arc::clone(schema.file_schema())), + ..self.clone() + }) + } + } +``` + +For implementations that need access to partition columns: + +```rust,ignore +fn with_schema(&self, schema: TableSchema) -> Arc { + Arc::new(Self { + file_schema: Arc::clone(schema.file_schema()), + partition_cols: schema.table_partition_cols().clone(), + 
table_schema: Arc::clone(schema.table_schema()), + ..self.clone() + }) +} +``` + +**Note**: Most `FileSource` implementations only need to store the file schema (without partition columns), as shown in the first example. The second pattern of storing all three schema components is typically only needed for advanced use cases where you need access to different schema representations for different operations (e.g., ParquetSource uses the file schema for building pruning predicates but needs the table schema for filter pushdown logic). + +**Using `TableSchema` directly:** + +If you're constructing a `FileScanConfig` or working with table schemas and partition columns, you can now use `TableSchema`: + +```rust +use datafusion_datasource::TableSchema; +use arrow::datatypes::{Schema, Field, DataType}; +use std::sync::Arc; + +// Create a TableSchema with partition columns +let file_schema = Arc::new(Schema::new(vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), +])); + +let partition_cols = vec![ + Arc::new(Field::new("date", DataType::Utf8, false)), + Arc::new(Field::new("region", DataType::Utf8, false)), +]; + +let table_schema = TableSchema::new(file_schema, partition_cols); + +// Access different schema representations +let file_schema_ref = table_schema.file_schema(); // Schema without partition columns +let full_schema = table_schema.table_schema(); // Complete schema with partition columns +let partition_cols_ref = table_schema.table_partition_cols(); // Just the partition columns +``` + +### `AggregateUDFImpl::is_ordered_set_aggregate` has been renamed to `AggregateUDFImpl::supports_within_group_clause` + +This method has been renamed to better reflect the actual impact it has for aggregate UDF implementations. +The accompanying `AggregateUDF::is_ordered_set_aggregate` has also been renamed to `AggregateUDF::supports_within_group_clause`. 
+No functionality has been changed with regards to this method; it still refers only to permitting use of `WITHIN GROUP` +SQL syntax for the aggregate function. diff --git a/docs/source/library-user-guide/upgrading/52.0.0.md b/docs/source/library-user-guide/upgrading/52.0.0.md new file mode 100644 index 0000000000000..4c659b6118fe4 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/52.0.0.md @@ -0,0 +1,669 @@ + + +# Upgrade Guides + +## DataFusion 52.0.0 + +### Changes to DFSchema API + +To permit more efficient planning, several methods on `DFSchema` have been +changed to return references to the underlying [`&FieldRef`] rather than +[`&Field`]. This allows planners to more cheaply copy the references via +`Arc::clone` rather than cloning the entire `Field` structure. + +You may need to change code to use `Arc::clone` instead of `.as_ref().clone()` +directly on the `Field`. For example: + +```diff +- let field = df_schema.field("my_column").as_ref().clone(); ++ let field = Arc::clone(df_schema.field("my_column")); +``` + +### ListingTableProvider now caches `LIST` commands + +In prior versions, `ListingTableProvider` would issue `LIST` commands to +the underlying object store each time it needed to list files for a query. +To improve performance, `ListingTableProvider` now caches the results of +`LIST` commands for the lifetime of the `ListingTableProvider` instance or +until a cache entry expires. + +Note that by default the cache has no expiration time, so if files are added or removed +from the underlying object store, the `ListingTableProvider` will not see +those changes until the `ListingTableProvider` instance is dropped and recreated. 
+ +You can configure the maximum cache size and cache entry expiration time via configuration options: + +- `datafusion.runtime.list_files_cache_limit` - Limits the size of the cache in bytes +- `datafusion.runtime.list_files_cache_ttl` - Limits the TTL (time-to-live) of an entry in seconds + +Detailed configuration information can be found in the [DataFusion Runtime +Configuration](https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings) user's guide. + +Caching can be disabled by setting the limit to 0: + +```sql +SET datafusion.runtime.list_files_cache_limit TO "0K"; +``` + +Note that the internal API has changed to use a trait `ListFilesCache` instead of a type alias. + +### `newlines_in_values` moved from `FileScanConfig` to `CsvOptions` + +The CSV-specific `newlines_in_values` configuration option has been moved from `FileScanConfig` to `CsvOptions`, as it only applies to CSV file parsing. + +**Who is affected:** + +- Users who set `newlines_in_values` via `FileScanConfigBuilder::with_newlines_in_values()` + +**Migration guide:** + +Set `newlines_in_values` in `CsvOptions` instead of on `FileScanConfigBuilder`: + +**Before:** + +```rust,ignore +let source = Arc::new(CsvSource::new(file_schema.clone())); +let config = FileScanConfigBuilder::new(object_store_url, source) + .with_newlines_in_values(true) + .build(); +``` + +**After:** + +```rust,ignore +let options = CsvOptions { + newlines_in_values: Some(true), + ..Default::default() +}; +let source = Arc::new(CsvSource::new(file_schema.clone()) + .with_csv_options(options)); +let config = FileScanConfigBuilder::new(object_store_url, source) + .build(); +``` + +### Removal of `pyarrow` feature + +The `pyarrow` feature flag has been removed. This feature has been migrated to +the `datafusion-python` repository since version `44.0.0`. 
+ +### Refactoring of `FileSource` constructors and `FileScanConfigBuilder` to accept schemas upfront + +The way schemas are passed to file sources and scan configurations has been significantly refactored. File sources now require the schema (including partition columns) to be provided at construction time, and `FileScanConfigBuilder` no longer takes a separate schema parameter. + +**Who is affected:** + +- Users who create `FileScanConfig` or file sources (`ParquetSource`, `CsvSource`, `JsonSource`, `AvroSource`) directly +- Users who implement custom `FileFormat` implementations + +**Key changes:** + +1. **FileSource constructors now require TableSchema**: All built-in file sources now take the schema in their constructor: + + ```diff + - let source = ParquetSource::default(); + + let source = ParquetSource::new(table_schema); + ``` + +2. **FileScanConfigBuilder no longer takes schema as a parameter**: The schema is now passed via the FileSource: + + ```diff + - FileScanConfigBuilder::new(url, schema, source) + + FileScanConfigBuilder::new(url, source) + ``` + +3. **Partition columns are now part of TableSchema**: The `with_table_partition_cols()` method has been removed from `FileScanConfigBuilder`. Partition columns are now passed as part of the `TableSchema` to the FileSource constructor: + + ```diff + + let table_schema = TableSchema::new( + + file_schema, + + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + + ); + + let source = ParquetSource::new(table_schema); + let config = FileScanConfigBuilder::new(url, source) + - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) + .with_file(partitioned_file) + .build(); + ``` + +4. 
**FileFormat::file_source() now takes TableSchema parameter**: Custom `FileFormat` implementations must be updated: + ```diff + impl FileFormat for MyFileFormat { + - fn file_source(&self) -> Arc { + + fn file_source(&self, table_schema: TableSchema) -> Arc { + - Arc::new(MyFileSource::default()) + + Arc::new(MyFileSource::new(table_schema)) + } + } + ``` + +**Migration examples:** + +For Parquet files: + +```diff +- let source = Arc::new(ParquetSource::default()); +- let config = FileScanConfigBuilder::new(url, schema, source) ++ let table_schema = TableSchema::new(schema, vec![]); ++ let source = Arc::new(ParquetSource::new(table_schema)); ++ let config = FileScanConfigBuilder::new(url, source) + .with_file(partitioned_file) + .build(); +``` + +For CSV files with partition columns: + +```diff +- let source = Arc::new(CsvSource::new(true, b',', b'"')); +- let config = FileScanConfigBuilder::new(url, file_schema, source) +- .with_table_partition_cols(vec![Field::new("year", DataType::Int32, false)]) ++ let options = CsvOptions { ++ has_header: Some(true), ++ delimiter: b',', ++ quote: b'"', ++ ..Default::default() ++ }; ++ let table_schema = TableSchema::new( ++ file_schema, ++ vec![Arc::new(Field::new("year", DataType::Int32, false))], ++ ); ++ let source = Arc::new(CsvSource::new(table_schema).with_csv_options(options)); ++ let config = FileScanConfigBuilder::new(url, source) + .build(); +``` + +### Adaptive filter representation in Parquet filter pushdown + +As of Arrow 57.1.0, DataFusion uses a new adaptive filter strategy when +evaluating pushed down filters for Parquet files. This new strategy improves +performance for certain types of queries where the results of filtering are +more efficiently represented with a bitmask rather than a selection. +See [arrow-rs #5523] for more details. + +This change only applies to the built-in Parquet data source with filter-pushdown enabled ( +which is [not yet the default behavior]). 
+ +You can disable the new behavior by setting the +`datafusion.execution.parquet.force_filter_selections` [configuration setting] to true. + +```sql +> set datafusion.execution.parquet.force_filter_selections = true; +``` + +[arrow-rs #5523]: https://github.com/apache/arrow-rs/issues/5523 +[configuration setting]: https://datafusion.apache.org/user-guide/configs.html +[not yet the default behavior]: https://github.com/apache/datafusion/issues/3463 + +### Statistics handling moved from `FileSource` to `FileScanConfig` + +Statistics are now managed directly by `FileScanConfig` instead of being delegated to `FileSource` implementations. This simplifies the `FileSource` trait and provides more consistent statistics handling across all file formats. + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations + +**Breaking changes:** + +Two methods have been removed from the `FileSource` trait: + +- `with_statistics(&self, statistics: Statistics) -> Arc` +- `statistics(&self) -> Result` + +**Migration guide:** + +If you have a custom `FileSource` implementation, you need to: + +1. Remove the `with_statistics` method implementation +2. Remove the `statistics` method implementation +3. Remove any internal state that was storing statistics + +**Before:** + +```rust,ignore +#[derive(Clone)] +struct MyCustomSource { + table_schema: TableSchema, + projected_statistics: Option, + // other fields... +} + +impl FileSource for MyCustomSource { + fn with_statistics(&self, statistics: Statistics) -> Arc { + Arc::new(Self { + table_schema: self.table_schema.clone(), + projected_statistics: Some(statistics), + // other fields... + }) + } + + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone().unwrap_or_else(|| + Statistics::new_unknown(self.table_schema.file_schema()) + )) + } + + // other methods... 
+} +``` + +**After:** + +```rust,ignore +#[derive(Clone)] +struct MyCustomSource { + table_schema: TableSchema, + // projected_statistics field removed + // other fields... +} + +impl FileSource for MyCustomSource { + // with_statistics method removed + // statistics method removed + + // other methods... +} +``` + +**Accessing statistics:** + +Statistics are now accessed through `FileScanConfig` instead of `FileSource`: + +```diff +- let stats = config.file_source.statistics()?; ++ let stats = config.statistics(); +``` + +Note that `FileScanConfig::statistics()` automatically marks statistics as inexact when filters are present, ensuring correctness when filters are pushed down. + +### Partition column handling moved out of `PhysicalExprAdapter` + +Partition column replacement is now a separate preprocessing step performed before expression rewriting via `PhysicalExprAdapter`. This change provides better separation of concerns and makes the adapter more focused on schema differences rather than partition value substitution. + +**Who is affected:** + +- Users who have custom implementations of `PhysicalExprAdapterFactory` that handle partition columns +- Users who directly use the `FilePruner` API + +**Breaking changes:** + +1. `FilePruner::try_new()` signature changed: the `partition_fields` parameter has been removed since partition column handling is now done separately +2. 
Partition column replacement must now be done via `replace_columns_with_literals()` before expressions are passed to the adapter + +**Migration guide:** + +If you have code that creates a `FilePruner` with partition fields: + +**Before:** + +```rust,ignore +use datafusion_pruning::FilePruner; + +let pruner = FilePruner::try_new( + predicate, + file_schema, + partition_fields, // This parameter is removed + file_stats, +)?; +``` + +**After:** + +```rust,ignore +use datafusion_pruning::FilePruner; + +// Partition fields are no longer needed +let pruner = FilePruner::try_new( + predicate, + file_schema, + file_stats, +)?; +``` + +If you have custom code that relies on `PhysicalExprAdapter` to handle partition columns, you must now call `replace_columns_with_literals()` separately: + +**Before:** + +```rust,ignore +// Adapter handled partition column replacement internally +let adapted_expr = adapter.rewrite(expr)?; +``` + +**After:** + +```rust,ignore +use datafusion_physical_expr_adapter::replace_columns_with_literals; + +// Replace partition columns first +let expr_with_literals = replace_columns_with_literals(expr, &partition_values)?; +// Then apply the adapter +let adapted_expr = adapter.rewrite(expr_with_literals)?; +``` + +### `build_row_filter` signature simplified + +The `build_row_filter` function in `datafusion-datasource-parquet` has been simplified to take a single schema parameter instead of two. +The expectation is now that the filter has been adapted to the physical file schema (the arrow representation of the parquet file's schema) before being passed to this function +using a `PhysicalExprAdapter` for example. 
+ +**Who is affected:** + +- Users who call `build_row_filter` directly + +**Breaking changes:** + +The function signature changed from: + +```rust,ignore +pub fn build_row_filter( + expr: &Arc<dyn PhysicalExpr>, + physical_file_schema: &SchemaRef, + predicate_file_schema: &SchemaRef, // removed + metadata: &ParquetMetaData, + reorder_predicates: bool, + file_metrics: &ParquetFileMetrics, +) -> Result<Option<RowFilter>> +``` + +To: + +```rust,ignore +pub fn build_row_filter( + expr: &Arc<dyn PhysicalExpr>, + file_schema: &SchemaRef, + metadata: &ParquetMetaData, + reorder_predicates: bool, + file_metrics: &ParquetFileMetrics, +) -> Result<Option<RowFilter>> +``` + +**Migration guide:** + +Remove the duplicate schema parameter from your call: + +```diff +- build_row_filter(&predicate, &file_schema, &file_schema, metadata, reorder, metrics) ++ build_row_filter(&predicate, &file_schema, metadata, reorder, metrics) +``` + +### Planner now requires explicit opt-in for WITHIN GROUP syntax + +The SQL planner now enforces the aggregate UDF contract more strictly: the +`WITHIN GROUP (ORDER BY ...)` syntax is accepted only if the aggregate UDAF +explicitly advertises support by returning `true` from +`AggregateUDFImpl::supports_within_group_clause()`. + +Previously the planner forwarded a `WITHIN GROUP` clause to order-sensitive +aggregates even when they did not implement ordered-set semantics, which could +cause queries such as `SUM(x) WITHIN GROUP (ORDER BY x)` to plan successfully. +This behavior was too permissive and has been changed to match PostgreSQL and +the documented semantics. + +Migration: If your UDAF intentionally implements ordered-set semantics and +wants to accept the `WITHIN GROUP` SQL syntax, update your implementation to +return `true` from `supports_within_group_clause()` and handle the ordering +semantics in your accumulator implementation.
If your UDAF is merely +order-sensitive (but not an ordered-set aggregate), do not advertise +`supports_within_group_clause()` and clients should use alternative function +signatures (for example, explicit ordering as a function argument) instead. + +### `AggregateUDFImpl::supports_null_handling_clause` now defaults to `false` + +This method specifies whether an aggregate function allows `IGNORE NULLS`/`RESPECT NULLS` +during SQL parsing, with the implication it respects these configs during computation. + +Most DataFusion aggregate functions silently ignored this syntax in prior versions +as they did not make use of it and it was permitted by default. We change this so +only the few functions which do respect this clause (e.g. `array_agg`, `first_value`, +`last_value`) need to implement it. + +Custom user defined aggregate functions will also error if this syntax is used, +unless they explicitly declare support by overriding the method. + +For example, SQL parsing will now fail for queries such as this: + +```sql +SELECT median(c1) IGNORE NULLS FROM table +``` + +Instead of silently succeeding. + +### API change for `CacheAccessor` trait + +The remove API no longer requires a mutable instance + +### FFI crate updates + +Many of the structs in the `datafusion-ffi` crate have been updated to allow easier +conversion to the underlying trait types they represent. This simplifies some code +paths, but also provides an additional improvement in cases where library code goes +through a round trip via the foreign function interface. + +To update your code, suppose you have a `FFI_SchemaProvider` called `ffi_provider` +and you wish to use this as a `SchemaProvider`. 
In the old approach you would do +something like: + +```rust,ignore + let foreign_provider: ForeignSchemaProvider = ffi_provider.into(); + let foreign_provider = Arc::new(foreign_provider) as Arc; +``` + +This code should now be written as: + +```rust,ignore + let foreign_provider: Arc = ffi_provider.into(); + let foreign_provider = foreign_provider as Arc; +``` + +For the case of user defined functions, the updates are similar but you +may need to change the way you call the creation of the `ScalarUDF`. +Aggregate and window functions follow the same pattern. + +Previously you may write: + +```rust,ignore + let foreign_udf: ForeignScalarUDF = ffi_udf.try_into()?; + let foreign_udf: ScalarUDF = foreign_udf.into(); +``` + +Instead this should now be: + +```rust,ignore + let foreign_udf: Arc = ffi_udf.into(); + let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); +``` + +When creating any of the following structs, we now require the user to +provide a `TaskContextProvider` and optionally a `LogicalExtensionCodec`: + +- `FFI_CatalogListProvider` +- `FFI_CatalogProvider` +- `FFI_SchemaProvider` +- `FFI_TableProvider` +- `FFI_TableFunction` + +Each of these structs has a `new()` and a `new_with_ffi_codec()` method for +instantiation. For example, when you previously would write + +```rust,ignore + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, None); +``` + +Now you will need to provide a `TaskContextProvider`. The most common +implementation of this trait is `SessionContext`. + +```rust,ignore + let ctx = Arc::new(SessionContext::default()); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new(table, None, ctx, None); +``` + +The alternative function to create these structures may be more convenient +if you are doing many of these operations. A `FFI_LogicalExtensionCodec` will +store the `TaskContextProvider` as well. 
+ +```rust,ignore + let codec = Arc::new(DefaultLogicalExtensionCodec {}); + let ctx = Arc::new(SessionContext::default()); + let ffi_codec = FFI_LogicalExtensionCodec::new(codec, None, ctx); + let table = Arc::new(MyTableProvider::new()); + let ffi_table = FFI_TableProvider::new_with_ffi_codec(table, None, ffi_codec); +``` + +Additional information about the usage of the `TaskContextProvider` can be +found in the crate README. + +Additionally, the FFI structure for Scalar UDF's no longer contains a +`return_type` call. This code was not used since the `ForeignScalarUDF` +struct implements the `return_field_from_args` instead. + +### Projection handling moved from FileScanConfig to FileSource + +Projection handling has been moved from `FileScanConfig` into `FileSource` implementations. This enables format-specific projection pushdown (e.g., Parquet can push down struct field access, Vortex can push down computed expressions into un-decoded data). + +**Who is affected:** + +- Users who have implemented custom `FileSource` implementations +- Users who use `FileScanConfigBuilder::with_projection_indices` directly + +**Breaking changes:** + +1. **`FileSource::with_projection` replaced with `try_pushdown_projection`:** + + The `with_projection(&self, config: &FileScanConfig) -> Arc` method has been removed and replaced with `try_pushdown_projection(&self, projection: &ProjectionExprs) -> Result>>`. + +2. **`FileScanConfig.projection_exprs` field removed:** + + Projections are now stored in the `FileSource` directly, not in `FileScanConfig`. + Various public helper methods that access projection information have been removed from `FileScanConfig`. + +3. **`FileScanConfigBuilder::with_projection_indices` now returns `Result`:** + + This method can now fail if the projection pushdown fails. + +4. **`FileSource::create_file_opener` now returns `Result>`:** + + Previously returned `Arc` directly. 
+ Any `FileSource` implementation that may fail to create a `FileOpener` should now return an appropriate error. + +5. **`DataSource::try_swapping_with_projection` signature changed:** + + Parameter changed from `&[ProjectionExpr]` to `&ProjectionExprs`. + +**Migration guide:** + +If you have a custom `FileSource` implementation: + +**Before:** + +```rust,ignore +impl FileSource for MyCustomSource { + fn with_projection(&self, config: &FileScanConfig) -> Arc { + // Apply projection from config + Arc::new(Self { /* ... */ }) + } + + fn create_file_opener( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> Arc { + Arc::new(MyOpener { /* ... */ }) + } +} +``` + +**After:** + +```rust,ignore +impl FileSource for MyCustomSource { + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + // Return None if projection cannot be pushed down + // Return Some(new_source) with projection applied if it can + Ok(Some(Arc::new(Self { + projection: Some(projection.clone()), + /* ... */ + }))) + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + + fn create_file_opener( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> Result> { + Ok(Arc::new(MyOpener { /* ... */ })) + } +} +``` + +We recommend you look at [#18627](https://github.com/apache/datafusion/pull/18627) +that introduced these changes for more examples for how this was handled for the various built in file sources. + +We have added [`SplitProjection`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.SplitProjection.html) and [`ProjectionOpener`](https://docs.rs/datafusion-datasource/latest/datafusion_datasource/projection/struct.ProjectionOpener.html) helpers to make it easier to handle projections in your `FileSource` implementations. 
+ +For file sources that can only handle simple column selections (not computed expressions), use the `SplitProjection` and `ProjectionOpener` helpers to split the projection into pushdownable and non-pushdownable parts: + +```rust,ignore +use datafusion_datasource::projection::{SplitProjection, ProjectionOpener}; + +// In try_pushdown_projection: +let split = SplitProjection::new(projection, self.table_schema())?; +// Use split.file_projection() for what to push down to the file format +// The ProjectionOpener wrapper will handle the rest +``` + +**For `FileScanConfigBuilder` users:** + +```diff +let config = FileScanConfigBuilder::new(url, source) +- .with_projection_indices(Some(vec![0, 2, 3])) ++ .with_projection_indices(Some(vec![0, 2, 3]))? + .build(); +``` + +### `SchemaAdapter` and `SchemaAdapterFactory` completely removed + +Following the deprecation announced in [DataFusion 49.0.0](49.0.0.md#deprecating-schemaadapterfactory-and-schemaadapter), `SchemaAdapterFactory` has been fully removed from Parquet scanning. This applies to both: + +The following symbols have been deprecated and will be removed in the next release: + +- `SchemaAdapter` trait +- `SchemaAdapterFactory` trait +- `SchemaMapper` trait +- `SchemaMapping` struct +- `DefaultSchemaAdapterFactory` struct + +These types were previously used to adapt record batch schemas during file reading. +This functionality has been replaced by `PhysicalExprAdapterFactory`, which rewrites expressions at planning time rather than transforming batches at runtime. +If you were using a custom `SchemaAdapterFactory` for schema adaptation (e.g., default column values, type coercion), you should now implement `PhysicalExprAdapterFactory` instead. +See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for how to implement a custom `PhysicalExprAdapterFactory`. 
+ +**Migration guide:** + +If you implemented a custom `SchemaAdapterFactory`, migrate to `PhysicalExprAdapterFactory`. +See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for a complete implementation. diff --git a/docs/source/library-user-guide/upgrading/53.0.0.md b/docs/source/library-user-guide/upgrading/53.0.0.md new file mode 100644 index 0000000000000..06c917b2ab925 --- /dev/null +++ b/docs/source/library-user-guide/upgrading/53.0.0.md @@ -0,0 +1,400 @@ + + +# Upgrade Guides + +## DataFusion 53.0.0 + +**Note:** DataFusion `53.0.0` has not been released yet. The information provided +in this section pertains to features and changes that have already been merged +to the main branch and are awaiting release in this version. See [#19692] for +more details. + +[#19692]: https://github.com/apache/datafusion/issues/19692 + +### `PlannerContext` outer query schema API now uses a stack + +`PlannerContext` no longer stores a single `outer_query_schema`. It now tracks a +stack of outer relation schemas so nested subqueries can access non-adjacent +outer relations. + +**Before:** + +```rust,ignore +let old_outer_query_schema = + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); +let sub_plan = self.query_to_plan(subquery, planner_context)?; +planner_context.set_outer_query_schema(old_outer_query_schema); +``` + +**After:** + +```rust,ignore +planner_context.append_outer_query_schema(input_schema.clone().into()); +let sub_plan = self.query_to_plan(subquery, planner_context)?; +planner_context.pop_outer_query_schema(); +``` + +### `FileSinkConfig` adds `file_output_mode` + +`FileSinkConfig` now includes a `file_output_mode: FileOutputMode` field to control +single-file vs directory output behavior. Any code constructing `FileSinkConfig` via struct +literals must initialize this field.
+ +The `FileOutputMode` enum has three variants: + +- `Automatic` (default): Infer output mode from the URL (extension/trailing `/` heuristic) +- `SingleFile`: Write to a single file at the exact output path +- `Directory`: Write to a directory with generated filenames + +**Before:** + +```rust,ignore +FileSinkConfig { + // ... + file_extension: "parquet".into(), +} +``` + +**After:** + +```rust,ignore +use datafusion_datasource::file_sink_config::FileOutputMode; + +FileSinkConfig { + // ... + file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, +} +``` + +### `SimplifyInfo` trait removed, `SimplifyContext` now uses builder-style API + +The `SimplifyInfo` trait has been removed and replaced with the concrete `SimplifyContext` struct. This simplifies the expression simplification API and removes the need for trait objects. + +**Who is affected:** + +- Users who implemented custom `SimplifyInfo` implementations +- Users who implemented `ScalarUDFImpl::simplify()` for custom scalar functions +- Users who directly use `SimplifyContext` or `ExprSimplifier` + +**Breaking changes:** + +1. The `SimplifyInfo` trait has been removed entirely +2. `SimplifyContext` no longer takes `&ExecutionProps` - it now uses a builder-style API with direct fields +3. `ScalarUDFImpl::simplify()` now takes `&SimplifyContext` instead of `&dyn SimplifyInfo` +4. Time-dependent function simplification (e.g., `now()`) is now optional - if `query_execution_start_time` is `None`, these functions won't be simplified + +**Migration guide:** + +If you implemented a custom `SimplifyInfo`: + +**Before:** + +```rust,ignore +impl SimplifyInfo for MySimplifyInfo { + fn is_boolean_type(&self, expr: &Expr) -> Result<bool> { ... } + fn nullable(&self, expr: &Expr) -> Result<bool> { ... } + fn execution_props(&self) -> &ExecutionProps { ... } + fn get_data_type(&self, expr: &Expr) -> Result<DataType> { ...
} +} +``` + +**After:** + +Use `SimplifyContext` directly with the builder-style API: + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_query_execution_start_time(Some(Utc::now())); // or use .with_current_time() +``` + +If you implemented `ScalarUDFImpl::simplify()`: + +**Before:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, +) -> Result { + let now_ts = info.execution_props().query_execution_start_time; + // ... +} +``` + +**After:** + +```rust,ignore +fn simplify( + &self, + args: Vec, + info: &SimplifyContext, +) -> Result { + // query_execution_start_time is now Option> + // Return Original if time is not set (simplification skipped) + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; + // ... +} +``` + +If you created `SimplifyContext` from `ExecutionProps`: + +**Before:** + +```rust,ignore +let props = ExecutionProps::new(); +let context = SimplifyContext::new(&props).with_schema(schema); +``` + +**After:** + +```rust,ignore +let context = SimplifyContext::default() + .with_schema(schema) + .with_config_options(config_options) + .with_current_time(); // Sets query_execution_start_time to Utc::now() +``` + +See [`SimplifyContext` documentation](https://docs.rs/datafusion-expr/latest/datafusion_expr/simplify/struct.SimplifyContext.html) for more details. + +### Struct Casting Now Requires Field Name Overlap + +DataFusion's struct casting mechanism previously allowed casting between structs with differing field names if the field counts matched. This "positional fallback" behavior could silently misalign fields and cause data corruption. + +**Breaking Change:** + +Starting with DataFusion 53.0.0, struct casts now require **at least one overlapping field name** between the source and target structs. Casts without field name overlap are rejected at plan time with a clear error message. 
+ +**Who is affected:** + +- Applications that cast between structs with no overlapping field names +- Queries that rely on positional struct field mapping (e.g., casting `struct(x, y)` to `struct(a, b)` based solely on position) +- Code that constructs or transforms struct columns programmatically + +**Migration guide:** + +If you encounter an error like: + +```text +Cannot cast struct with 2 fields to 2 fields because there is no field name overlap +``` + +You must explicitly rename or map fields to ensure at least one field name matches. Here are common patterns: + +**Example 1: Source and target field names already match (Name-based casting)** + +**Success case (field names align):** + +```sql +-- source_col has schema: STRUCT +-- Casting to the same field names succeeds (no-op or type validation only) +SELECT CAST(source_col AS STRUCT) FROM table1; +``` + +**Example 2: Source and target field names differ (Migration scenario)** + +**What fails now (no field name overlap):** + +```sql +-- source_col has schema: STRUCT +-- This FAILS because there is no field name overlap: +-- ❌ SELECT CAST(source_col AS STRUCT) FROM table1; +-- Error: Cannot cast struct with 2 fields to 2 fields because there is no field name overlap +``` + +**Migration options (must align names):** + +**Option A: Use struct constructor for explicit field mapping** + +```sql +-- source_col has schema: STRUCT +-- Use STRUCT_CONSTRUCT with explicit field names +SELECT STRUCT_CONSTRUCT( + 'x', source_col.a, + 'y', source_col.b +) AS renamed_struct FROM table1; +``` + +**Option B: Rename in the cast target to match source names** + +```sql +-- source_col has schema: STRUCT +-- Cast to target with matching field names +SELECT CAST(source_col AS STRUCT) FROM table1; +``` + +**Example 3: Using struct constructors in Rust API** + +If you need to map fields programmatically, build the target struct explicitly: + +```rust,ignore +// Build the target struct with explicit field names +let 
target_struct_type = DataType::Struct(vec![ + FieldRef::new("x", DataType::Int32), + FieldRef::new("y", DataType::Utf8), +]); + +// Use struct constructors rather than casting for field mapping +// This makes the field mapping explicit and unambiguous +// Use struct builders or row constructors that preserve your mapping logic +``` + +**Why this change:** + +1. **Safety:** Field names are now the primary contract for struct compatibility +2. **Explicitness:** Prevents silent data misalignment caused by positional assumptions +3. **Consistency:** Matches DuckDB's behavior and aligns with other SQL engines that enforce name-based matching +4. **Debuggability:** Errors now appear at plan time rather than as silent data corruption + +See [Issue #19841](https://github.com/apache/datafusion/issues/19841) and [PR #19955](https://github.com/apache/datafusion/pull/19955) for more details. + +### `FilterExec` builder methods deprecated + +The following methods on `FilterExec` have been deprecated in favor of using `FilterExecBuilder`: + +- `with_projection()` +- `with_batch_size()` + +**Who is affected:** + +- Users who create `FilterExec` instances and use these methods to configure them + +**Migration guide:** + +Use `FilterExecBuilder` instead of chaining method calls on `FilterExec`: + +**Before:** + +```rust,ignore +let filter = FilterExec::try_new(predicate, input)? + .with_projection(Some(vec![0, 2]))? + .with_batch_size(8192)?; +``` + +**After:** + +```rust,ignore +let filter = FilterExecBuilder::new(predicate, input) + .with_projection(Some(vec![0, 2])) + .with_batch_size(8192) + .build()?; +``` + +The builder pattern is more efficient as it computes properties once during `build()` rather than recomputing them for each method call. + +Note: `with_default_selectivity()` is not deprecated as it simply updates a field value and does not require the overhead of the builder pattern. 
+ +### Protobuf conversion trait added + +A new trait, `PhysicalProtoConverterExtension`, has been added to the `datafusion-proto` +crate. This is used for controlling the process of conversion of physical plans and +expressions to and from their protobuf equivalents. The methods for conversion now +require an additional parameter. + +The primary APIs for interacting with this crate have not been modified, so most users +should not need to make any changes. If you do require this trait, you can use the +`DefaultPhysicalProtoConverter` implementation. + +For example, to convert a sort expression protobuf node you can make the following +updates: + +**Before:** + +```rust,ignore +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_expr = parse_physical_sort_expr( + sort_proto, + ctx, + input_schema, + codec, + &converter +); +``` + +Similarly to convert from a physical sort expression into a protobuf node: + +**Before:** + +```rust,ignore +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, +); +``` + +**After:** + +```rust,ignore +let converter = DefaultPhysicalProtoConverter {}; +let sort_proto = serialize_physical_sort_expr( + sort_expr, + codec, + &converter, +); +``` + +### `generate_series` and `range` table functions changed + +The `generate_series` and `range` table functions now return an empty set when the interval is invalid, instead of an error. +This behavior is consistent with systems like PostgreSQL. 
+ +Before: + +```sql +> select * from generate_series(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series + +> select * from range(0, -1); +Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +``` + +Now: + +```sql +> select * from generate_series(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. + +> select * from range(0, -1); ++-------+ +| value | ++-------+ ++-------+ +0 row(s) fetched. +``` diff --git a/docs/source/library-user-guide/upgrading/index.rst b/docs/source/library-user-guide/upgrading/index.rst new file mode 100644 index 0000000000000..16bb33b7592ae --- /dev/null +++ b/docs/source/library-user-guide/upgrading/index.rst @@ -0,0 +1,32 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Upgrade Guides +============== + +.. 
toctree:: + :maxdepth: 1 + + DataFusion 53.0.0 <53.0.0> + DataFusion 52.0.0 <52.0.0> + DataFusion 51.0.0 <51.0.0> + DataFusion 50.0.0 <50.0.0> + DataFusion 49.0.0 <49.0.0> + DataFusion 48.0.1 <48.0.1> + DataFusion 48.0.0 <48.0.0> + DataFusion 47.0.0 <47.0.0> + DataFusion 46.0.0 <46.0.0> diff --git a/docs/source/user-guide/arrow-introduction.md b/docs/source/user-guide/arrow-introduction.md index 89662a0c29c5d..5a225782adfdb 100644 --- a/docs/source/user-guide/arrow-introduction.md +++ b/docs/source/user-guide/arrow-introduction.md @@ -220,14 +220,15 @@ When working with Arrow and RecordBatches, watch out for these common issues: - [Schema](https://docs.rs/arrow-schema/latest/arrow_schema/struct.Schema.html) - Describes the structure of a RecordBatch (column names and types) [apache arrow]: https://arrow.apache.org/docs/index.html +[arrow-rs]: https://github.com/apache/arrow-rs [`arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html [`arrayref`]: https://docs.rs/arrow-array/latest/arrow_array/array/type.ArrayRef.html [`cast`]: https://docs.rs/arrow/latest/arrow/compute/fn.cast.html [`field`]: https://docs.rs/arrow-schema/latest/arrow_schema/struct.Field.html [`schema`]: https://docs.rs/arrow-schema/latest/arrow_schema/struct.Schema.html [`datatype`]: https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html -[`int32array`]: https://docs.rs/arrow-array/latest/arrow_array/array/struct.Int32Array.html -[`stringarray`]: https://docs.rs/arrow-array/latest/arrow_array/array/struct.StringArray.html +[`int32array`]: https://docs.rs/arrow/latest/arrow/array/type.Int32Array.html +[`stringarray`]: https://docs.rs/arrow/latest/arrow/array/type.StringArray.html [`int32`]: https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html#variant.Int32 [`int64`]: https://docs.rs/arrow-schema/latest/arrow_schema/enum.DataType.html#variant.Int64 [extension points]: ../library-user-guide/extensions.md @@ -241,8 +242,8 @@ When working with Arrow and RecordBatches, 
watch out for these common issues: [`.show()`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.show [`memtable`]: https://docs.rs/datafusion/latest/datafusion/datasource/struct.MemTable.html [`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html -[`csvreadoptions`]: https://docs.rs/datafusion/latest/datafusion/execution/options/struct.CsvReadOptions.html -[`parquetreadoptions`]: https://docs.rs/datafusion/latest/datafusion/execution/options/struct.ParquetReadOptions.html +[`csvreadoptions`]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/options/struct.CsvReadOptions.html +[`parquetreadoptions`]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/options/struct.ParquetReadOptions.html [`recordbatch`]: https://docs.rs/arrow-array/latest/arrow_array/struct.RecordBatch.html [`read_csv`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.read_csv [`read_parquet`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.read_parquet diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index ad444ef91c474..3946ca7b16f63 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -21,7 +21,7 @@ ## 🧭 Background Concepts -- **2024-06-13**: [2024 ACM SIGMOD International Conference on Management of Data: Apache Arrow DataFusion: A Fast, Embeddable, Modular Analytic Query Engine](https://dl.acm.org/doi/10.1145/3626246.3653368) - [Download](http://andrew.nerdnetworks.org/other/SIGMOD-2024-lamb.pdf), [Talk](https://youtu.be/-DpKcPfnNms), [Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p), [Recording ](https://youtu.be/-DpKcPfnNms) +- **2024-06-13**: [2024 ACM SIGMOD International 
Conference on Management of Data: Apache Arrow DataFusion: A Fast, Embeddable, Modular Analytic Query Engine](https://dl.acm.org/doi/10.1145/3626246.3653368) - [Download](https://andrew.nerdnetworks.org/pdf/SIGMOD-2024-lamb.pdf), [Talk](https://youtu.be/-DpKcPfnNms), [Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p), [Recording ](https://youtu.be/-DpKcPfnNms) - **2024-06-07**: [Video: SIGMOD 2024 Practice: Apache Arrow DataFusion A Fast, Embeddable, Modular Analytic Query Engine](https://www.youtube.com/watch?v=-DpKcPfnNms&t=5s) - [Slides](https://docs.google.com/presentation/d/1gqcxSNLGVwaqN0_yJtCbNm19-w5pqPuktII5_EDA6_k/edit#slide=id.p) @@ -37,6 +37,34 @@ This is a list of DataFusion related blog posts, articles, and other resources. Please open a PR to add any new resources you create or find +- **2026-01-12** [Blog: Extending SQL in DataFusion: from ->> to TABLESAMPLE](https://datafusion.apache.org/blog/2026/01/12/extending-sql) + +- **2025-12-15** [Blog: Optimizing Repartitions in DataFusion: How I Went From Database Noob to Core Contribution](https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions) + +- **2025-09-21** [Blog: Implementing User Defined Types and Custom Metadata in DataFusion](https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata) + +- **2025-09-10** [Blog: Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries](https://datafusion.apache.org/blog/2025/09/10/dynamic-filters) + +- **2025-08-15** [Blog: Using External Indexes, Metadata Stores, Catalogs and Caches to Accelerate Queries on Apache Parquet](https://datafusion.apache.org/blog/2025/08/15/external-parquet-indexes) + +- **2025-07-14** [Blog: Embedding User-Defined Indexes in Apache Parquet Files](https://datafusion.apache.org/blog/2025/07/14/user-defined-parquet-indexes) + +- **2025-06-30** [Blog: Using Rust async for Query Execution and Cancelling 
Long-Running Queries](https://datafusion.apache.org/blog/2025/06/30/cancellation) + +- **2025-06-15** [Blog: Optimizing SQL (and DataFrames) in DataFusion, Part 1: Query Optimization Overview](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-one) + +- **2025-06-15** [Blog: Optimizing SQL (and DataFrames) in DataFusion, Part 2: Optimizers in Apache DataFusion](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two) + +- **2025-04-19** [Blog: User defined Window Functions in DataFusion](https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions) + +- **2025-04-10** [Blog: tpchgen-rs World's fastest open source TPC-H data generator, written in Rust](https://datafusion.apache.org/blog/2025/04/10/fastest-tpch-generator) + +- **2025-03-11** [Blog: Using Ordering for Better Plans in Apache DataFusion](https://datafusion.apache.org/blog/2025/03/11/ordering-analysis) + +- **2024-05-07** [Blog: Announcing Apache Arrow DataFusion is now Apache DataFusion](https://datafusion.apache.org/blog/2024/05/07/datafusion-tlp) + +- **2024-03-06** [Blog: Announcing Apache Arrow DataFusion Comet](https://datafusion.apache.org/blog/2024/03/06/comet-donation) + - **2025-03-21** [Blog: Efficient Filter Pushdown in Parquet](https://datafusion.apache.org/blog/2025/03/21/parquet-pushdown/) - **2025-03-20** [Blog: Parquet Pruning in DataFusion: Read Only What Matters](https://datafusion.apache.org/blog/2025/03/20/parquet-pruning/) @@ -59,16 +87,14 @@ This is a list of DataFusion related blog posts, articles, and other resources. 
- **2024-10-29** [Video: MiDAS Seminar Fall 2024 on "Apache DataFusion" by Andrew Lamb](https://www.youtube.com/watch?v=CpnxuBwHbUc) -- **2024-10-27** [Blog: Caching in DataFusion: Don't read twice](https://blog.haoxp.xyz/posts/caching-datafusion) +- **2024-10-27** [Blog: Caching in DataFusion: Don't read twice](https://blog.xiangpeng.systems/posts/caching-datafusion/) -- **2024-10-24** [Blog: Parquet pruning in DataFusion: Read no more than you need](https://blog.haoxp.xyz/posts/parquet-to-arrow/) +- **2024-10-24** [Blog: Parquet pruning in DataFusion: Read no more than you need](https://blog.xiangpeng.systems/posts/parquet-to-arrow/) - **2024-09-13** [Blog: Using StringView / German Style Strings to make Queries Faster: Part 2 - String Operations](https://www.influxdata.com/blog/faster-queries-with-stringview-part-two-influxdb/) | [Reposted on DataFusion Blog](https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-2/) - **2024-09-13** [Blog: Using StringView / German Style Strings to Make Queries Faster: Part 1- Reading Parquet](https://www.influxdata.com/blog/faster-queries-with-stringview-part-one-influxdb/) | [Reposted on Datafusion Blog](https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/) -- **2024-10-16** [Blog: Candle Image Segmentation](https://www.letsql.com/posts/candle-image-segmentation/) - - **2024-09-23 → 2024-12-02** [Talks: Carnegie Mellon University: Database Building Blocks Seminar Series - Fall 2024](https://db.cs.cmu.edu/seminar2024/) - **2024-11-12** [Video: Building InfluxDB 3.0 with the FDAP Stack: Apache Flight, DataFusion, Arrow and Parquet (Paul Dix)](https://www.youtube.com/watch?v=AGS4GNGDK_4) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index f0ee0cbbc4e55..e48f0a7c92276 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -93,7 +93,7 @@ The following configuration settings are available: | 
datafusion.execution.parquet.bloom_filter_on_read | true | (reading) Use any available bloom filters when reading parquet files | | datafusion.execution.parquet.max_predicate_cache_size | NULL | (reading) The maximum predicate cache size, in bytes. When `pushdown_filters` is enabled, sets the maximum memory used to cache the results of predicate evaluation between filter evaluation and output generation. Decreasing this value will reduce memory usage, but may increase IO and CPU usage. None means use the default parquet reader setting. 0 means no caching. | | datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in rows | | datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | | datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | | datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. 
| @@ -101,7 +101,7 @@ The following configuration settings are available: | datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | | datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 52.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 52.1.0 | (writing) Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | | datafusion.execution.parquet.statistics_truncate_length | 64 | (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | @@ -165,6 +165,7 @@ The following configuration settings are available: | datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | | datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | | datafusion.optimizer.enable_sort_pushdown | true | Enable sort pushdown optimization. 
When enabled, attempts to push sort requirements down to data sources that can natively handle them (e.g., by reversing file/row group read order). Returns **inexact ordering**: Sort operator is kept for correctness, but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), providing significant speedup. Memory: No additional overhead (only changes read order). Future: Will add option to detect perfectly sorted data and eliminate Sort completely. Default: true | +| datafusion.optimizer.enable_leaf_expression_pushdown | true | When set to true, the optimizer will extract leaf expressions (such as `get_field`) from filter/sort/join nodes into projections closer to the leaf table scans, and push those projections down towards the leaf nodes. | | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md index 83a46b50c004f..14827a8c2c721 100644 --- a/docs/source/user-guide/crate-configuration.md +++ b/docs/source/user-guide/crate-configuration.md @@ -24,6 +24,7 @@ your Rust project. The [Configuration Settings] section lists options that control additional aspects DataFusion's runtime behavior. [configuration settings]: configs.md +[support for adding dependencies]: https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies ## Using the nightly DataFusion builds @@ -155,7 +156,7 @@ By default, Datafusion returns errors as a plain text message. 
You can enable mo such as backtraces by enabling the `backtrace` feature to your `Cargo.toml` file like this: ```toml -datafusion = { version = "31.0.0", features = ["backtrace"]} +datafusion = { version = "52.0.0", features = ["backtrace"]} ``` Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 6108315f398aa..46006c62241db 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -29,7 +29,7 @@ Find latest available Datafusion version on [DataFusion's crates.io] page. Add the dependency to your `Cargo.toml` file: ```toml -datafusion = "latest_version" +datafusion = "52.0.0" tokio = { version = "1.0", features = ["rt-multi-thread"] } ``` @@ -103,8 +103,8 @@ exported by DataFusion, for example: use datafusion::arrow::datatypes::Schema; ``` -For example, [DataFusion `25.0.0` dependencies] require `arrow` -`39.0.0`. If instead you used `arrow` `40.0.0` in your project you may +For example, [DataFusion `26.0.0` dependencies] require `arrow` +`40.0.0`. If instead you used `arrow` `41.0.0` in your project you may see errors such as: ```text diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index 5a1184539c034..c047659e9940d 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -226,8 +226,10 @@ Again, reading from bottom up: When predicate pushdown is enabled, `DataSourceExec` with `ParquetSource` gains the following metrics: - `page_index_rows_pruned`: number of rows evaluated by page index filters. The metric reports both how many rows were considered in total and how many matched (were not pruned). +- `page_index_pages_pruned`: number of pages evaluated by page index filters. The metric reports both how many pages were considered in total and how many matched (were not pruned). 
- `row_groups_pruned_bloom_filter`: number of row groups evaluated by Bloom Filters, reporting both total checked groups and groups that matched. - `row_groups_pruned_statistics`: number of row groups evaluated by row-group statistics (min/max), reporting both total checked groups and groups that matched. +- `limit_pruned_row_groups`: number of row groups pruned by the limit. - `pushdown_rows_matched`: rows that were tested by any of the above filters, and passed all of them. - `pushdown_rows_pruned`: rows that were tested by any of the above filters, and did not pass at least one of them. - `predicate_evaluation_errors`: number of times evaluating the filter expression failed (expected to be zero in normal operation) diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 02edb6371ce3e..502193df41a64 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -25,6 +25,11 @@ execution. The SQL types from are mapped to [Arrow data types](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) according to the following table. This mapping occurs when defining the schema in a `CREATE EXTERNAL TABLE` command or when performing a SQL `CAST` operation. +For background on extension types and custom metadata, see the +[Implementing User Defined Types and Custom Metadata in DataFusion] blog. + +[implementing user defined types and custom metadata in datafusion]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata + You can see the corresponding Arrow type for any SQL expression using the `arrow_typeof` function. 
For example: @@ -64,27 +69,32 @@ select arrow_cast(now(), 'Timestamp(Second, None)') as "now()"; | SQL DataType | Arrow DataType | | ------------ | -------------- | -| `CHAR` | `Utf8` | -| `VARCHAR` | `Utf8` | -| `TEXT` | `Utf8` | -| `STRING` | `Utf8` | +| `CHAR` | `Utf8View` | +| `VARCHAR` | `Utf8View` | +| `TEXT` | `Utf8View` | +| `STRING` | `Utf8View` | + +By default, string types are mapped to `Utf8View`. This can be configured using the `datafusion.sql_parser.map_string_types_to_utf8view` setting. When set to `false`, string types are mapped to `Utf8` instead. ## Numeric Types -| SQL DataType | Arrow DataType | -| ------------------------------------ | :----------------------------- | -| `TINYINT` | `Int8` | -| `SMALLINT` | `Int16` | -| `INT` or `INTEGER` | `Int32` | -| `BIGINT` | `Int64` | -| `TINYINT UNSIGNED` | `UInt8` | -| `SMALLINT UNSIGNED` | `UInt16` | -| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | -| `BIGINT UNSIGNED` | `UInt64` | -| `FLOAT` | `Float32` | -| `REAL` | `Float32` | -| `DOUBLE` | `Float64` | -| `DECIMAL(precision, scale)` | `Decimal128(precision, scale)` | +| SQL DataType | Arrow DataType | +| ------------------------------------------------ | :----------------------------- | +| `TINYINT` | `Int8` | +| `SMALLINT` | `Int16` | +| `INT` or `INTEGER` | `Int32` | +| `BIGINT` | `Int64` | +| `TINYINT UNSIGNED` | `UInt8` | +| `SMALLINT UNSIGNED` | `UInt16` | +| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | +| `BIGINT UNSIGNED` | `UInt64` | +| `FLOAT` | `Float32` | +| `REAL` | `Float32` | +| `DOUBLE` | `Float64` | +| `DECIMAL(precision, scale)` where precision ≤ 38 | `Decimal128(precision, scale)` | +| `DECIMAL(precision, scale)` where precision > 38 | `Decimal256(precision, scale)` | + +The maximum supported precision for `DECIMAL` types is 76. 
## Date/Time Types @@ -126,42 +136,3 @@ You can create binary literals using a hex string literal such as | `ENUM` | _Not yet supported_ | | `SET` | _Not yet supported_ | | `DATETIME` | _Not yet supported_ | - -## Supported Arrow Types - -The following types are supported by the `arrow_typeof` function: - -| Arrow Type | -| ----------------------------------------------------------- | -| `Null` | -| `Boolean` | -| `Int8` | -| `Int16` | -| `Int32` | -| `Int64` | -| `UInt8` | -| `UInt16` | -| `UInt32` | -| `UInt64` | -| `Float16` | -| `Float32` | -| `Float64` | -| `Utf8` | -| `LargeUtf8` | -| `Binary` | -| `Timestamp(Second, None)` | -| `Timestamp(Millisecond, None)` | -| `Timestamp(Microsecond, None)` | -| `Timestamp(Nanosecond, None)` | -| `Time32` | -| `Time64` | -| `Duration(Second)` | -| `Duration(Millisecond)` | -| `Duration(Microsecond)` | -| `Duration(Nanosecond)` | -| `Interval(YearMonth)` | -| `Interval(DayTime)` | -| `Interval(MonthDayNano)` | -| `FixedSizeBinary()` (e.g. `FixedSizeBinary(16)`) | -| `Decimal128(, )` e.g. `Decimal128(3, 10)` | -| `Decimal256(, )` e.g. `Decimal256(3, 10)` | diff --git a/docs/source/user-guide/sql/format_options.md b/docs/source/user-guide/sql/format_options.md index c04a6b5d52ca5..338508031413c 100644 --- a/docs/source/user-guide/sql/format_options.md +++ b/docs/source/user-guide/sql/format_options.md @@ -153,7 +153,7 @@ The following options are available when reading or writing Parquet files. If an | DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | | DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. | `'data_page_row_count_limit'` | 20000 | | DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | -| WRITE_BATCH_SIZE | No | Sets write_batch_size in bytes. 
| `'write_batch_size'` | 1024 | +| WRITE_BATCH_SIZE | No | Sets write_batch_size in rows. | `'write_batch_size'` | 1024 | | WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). | `'writer_version'` | 1.0 | | SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | | CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index a13d40334b639..f1fef45f705a8 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -22,6 +22,7 @@ SQL Reference :maxdepth: 2 data_types + struct_coercion select subqueries ddl diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 4079802d9e630..cfd8e68e11d9c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1891,7 +1891,7 @@ split_part(str, delimiter, pos) - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **delimiter**: String or character to split on. -- **pos**: Position of the part to return. +- **pos**: Position of the part to return (counting from 1). Negative values count backward from the end of the string. #### Example @@ -2068,17 +2068,17 @@ to_hex(int) ### `translate` -Translates characters in a string to specified translation characters. +Performs character-wise substitution based on a mapping. ```sql -translate(str, chars, translation) +translate(str, from, to) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. 
Translation characters replace only characters at the same position in the **chars** string. +- **from**: The characters to be replaced. +- **to**: The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping. #### Example @@ -2175,7 +2175,7 @@ encode(expression, format) #### Arguments - **expression**: Expression containing string or binary data -- **format**: Supported formats are: `base64`, `hex` +- **format**: Supported formats are: `base64`, `base64pad`, `hex` **Related functions**: @@ -2387,6 +2387,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo - [date_trunc](#date_trunc) - [datepart](#datepart) - [datetrunc](#datetrunc) +- [extract](#extract) - [from_unixtime](#from_unixtime) - [make_date](#make_date) - [make_time](#make_time) @@ -2519,6 +2520,7 @@ date_part(part, expression) - **part**: Part of the date to return. The following date parts are supported: - year + - isoyear (ISO 8601 week-numbering year) - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - month - week (week of the year) @@ -2531,7 +2533,7 @@ date_part(part, expression) - nanosecond - dow (day of the week where Sunday is 0) - doy (day of the year) - - epoch (seconds since Unix epoch) + - epoch (seconds since Unix epoch for timestamps/dates, total seconds for intervals) - isodow (day of the week where Monday is 0) - **expression**: Time expression to operate on. Can be a constant, column, or function. 
@@ -2545,6 +2547,7 @@ extract(field FROM source) #### Aliases - datepart +- extract ### `date_trunc` @@ -2593,6 +2596,10 @@ _Alias of [date_part](#date_part)._ _Alias of [date_trunc](#date_trunc)._ +### `extract` + +_Alias of [date_part](#date_part)._ + ### `from_unixtime` Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. @@ -2879,7 +2886,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -2927,7 +2935,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. 
Integers, unsigned integers, and doubles are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -2970,7 +2979,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -3013,7 +3023,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000`) in the session time zone. Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone. Integers, unsigned integers, and doubles are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. @@ -3055,7 +3066,8 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo Converts a value to a timestamp (`YYYY-MM-DDT00:00:00`) in the session time zone. 
Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -if no [Chrono formats] are provided. Strings that parse without a time zone are treated as if they are in the +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the session time zone, or UTC if no session time zone is set. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). @@ -3096,7 +3108,11 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `to_unixtime` -Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`). +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). +Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`). ```sql to_unixtime(expression[, ..., format_n]) @@ -4221,7 +4237,7 @@ array_to_string(array, delimiter[, null_string]) ### `array_union` -Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. +Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates. 
```sql array_union(array1, array2) diff --git a/docs/source/user-guide/sql/struct_coercion.md b/docs/source/user-guide/sql/struct_coercion.md new file mode 100644 index 0000000000000..d2a32fcee2650 --- /dev/null +++ b/docs/source/user-guide/sql/struct_coercion.md @@ -0,0 +1,354 @@ + + +# Struct Type Coercion and Field Mapping + +DataFusion uses **name-based field mapping** when coercing struct types across different operations. This document explains how struct coercion works, when it applies, and how to handle NULL fields. + +## Overview: Name-Based vs Positional Mapping + +When combining structs from different sources (e.g., in UNION, array construction, or JOINs), DataFusion matches struct fields by **name** rather than by **position**. This provides more robust and predictable behavior compared to positional matching. + +### Example: Field Reordering is Handled Transparently + +```sql +-- These two structs have the same fields in different order +SELECT [{a: 1, b: 2}, {b: 3, a: 4}]; + +-- Result: Field names matched, values unified +-- [{"a": 1, "b": 2}, {"a": 4, "b": 3}] +``` + +## Coercion Paths Using Name-Based Matching + +The following query operations use name-based field mapping for struct coercion: + +### 1. Array Literal Construction + +When creating array literals with struct elements that have different field orders: + +```sql +-- Structs with reordered fields in array literal +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; + +-- Unified type: List(Struct("x": Int32, "y": Int32)) +-- Values: [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +**When it applies:** + +- Array literals with struct elements: `[{...}, {...}]` +- Nested arrays with structs: `[[{x: 1}, {x: 2}]]` + +### 2. 
Array Construction from Columns + +When constructing arrays from table columns with different struct schemas: + +```sql +CREATE TABLE t_left (s struct(x int, y int)) AS VALUES ({x: 1, y: 2}); +CREATE TABLE t_right (s struct(y int, x int)) AS VALUES ({y: 3, x: 4}); + +-- Dynamically constructs unified array schema +SELECT [t_left.s, t_right.s] FROM t_left JOIN t_right; + +-- Result: [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +**When it applies:** + +- Array construction with column references: `[col1, col2]` +- Array construction in joins with matching field names + +### 3. UNION Operations + +When combining query results with different struct field orders: + +```sql +SELECT {a: 1, b: 2} as s +UNION ALL +SELECT {b: 3, a: 4} as s; + +-- Result: {"a": 1, "b": 2} and {"a": 4, "b": 3} +``` + +**When it applies:** + +- UNION ALL with structs: field names matched across branches +- UNION (deduplicated) with structs + +### 4. Common Table Expressions (CTEs) + +When multiple CTEs produce structs with different field orders that are combined: + +```sql +WITH + t1 AS (SELECT {a: 1, b: 2} as s), + t2 AS (SELECT {b: 3, a: 4} as s) +SELECT s FROM t1 +UNION ALL +SELECT s FROM t2; + +-- Result: Field names matched across CTEs +``` + +### 5. VALUES Clauses + +When creating tables or temporary results with struct values in different field orders: + +```sql +CREATE TABLE t AS VALUES ({a: 1, b: 2}), ({b: 3, a: 4}); + +-- Table schema unified: struct(a: int, b: int) +-- Values: {a: 1, b: 2} and {a: 4, b: 3} +``` + +### 6. JOIN Operations + +When joining tables where the JOIN condition involves structs with different field orders: + +```sql +CREATE TABLE orders (customer struct(name varchar, id int)); +CREATE TABLE customers (info struct(id int, name varchar)); + +-- Join matches struct fields by name +SELECT * FROM orders +JOIN customers ON orders.customer = customers.info; +``` + +### 7. 
Aggregate Functions + +When collecting structs with different field orders using aggregate functions like `array_agg`: + +```sql +SELECT array_agg(s) FROM ( + SELECT {x: 1, y: 2} as s, 1 as category + UNION ALL + SELECT {y: 3, x: 4} as s, 1 as category +) t +GROUP BY category; + +-- Result: Array of structs with unified field order +``` + +### 8. Window Functions + +When using window functions with struct expressions having different field orders: + +```sql +SELECT + id, + row_number() over (partition by s order by id) as rn +FROM ( + SELECT {category: 1, value: 10} as s, 1 as id + UNION ALL + SELECT {value: 20, category: 1} as s, 2 as id +); + +-- Fields matched by name in PARTITION BY clause +``` + +## NULL Handling for Missing Fields + +When structs have different field sets, missing fields are filled with **NULL** values during coercion. + +### Example: Partial Field Overlap + +```sql +-- Struct in first position has fields: a, b +-- Struct in second position has fields: b, c +-- Unified schema includes all fields: a, b, c + +SELECT [ + CAST({a: 1, b: 2} AS STRUCT(a INT, b INT, c INT)), + CAST({b: 3, c: 4} AS STRUCT(a INT, b INT, c INT)) +]; + +-- Result: +-- [ +-- {"a": 1, "b": 2, "c": NULL}, +-- {"a": NULL, "b": 3, "c": 4} +-- ] +``` + +### Limitations + +**Field count must match exactly.** If structs have different numbers of fields and their field names don't completely overlap, the query will fail: + +```sql +-- This fails because field sets don't match: +-- t_left has {x, y} but t_right has {x, y, z} +SELECT [t_left.s, t_right.s] FROM t_left JOIN t_right; +-- Error: Cannot coerce struct with mismatched field counts +``` + +**Workaround: Use explicit CAST** + +To handle partial field overlap, explicitly cast structs to a unified schema: + +```sql +SELECT [ + CAST(t_left.s AS STRUCT(x INT, y INT, z INT)), + CAST(t_right.s AS STRUCT(x INT, y INT, z INT)) +] FROM t_left JOIN t_right; +``` + +## Migration Guide: From Positional to Name-Based Matching + +If you have existing code that 
relied on **positional** struct field matching, you may need to update it. + +### Example: Query That Changes Behavior + +**Old behavior (positional):** + +```sql +-- These would have been positionally mapped (left-to-right) +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; +-- Old result (positional): [{"x": 1, "y": 2}, {"y": 3, "x": 4}] +``` + +**New behavior (name-based):** + +```sql +-- Now uses name-based matching +SELECT [{x: 1, y: 2}, {y: 3, x: 4}]; +-- New result (by name): [{"x": 1, "y": 2}, {"x": 4, "y": 3}] +``` + +### Migration Steps + +1. **Review struct operations** - Look for queries that combine structs from different sources +2. **Check field names** - Verify that field names match as expected (not positions) +3. **Test with new coercion** - Run queries and verify the results match your expectations +4. **Handle field reordering** - If you need specific field orders, use explicit CAST operations + +### Using Explicit CAST for Compatibility + +If you need precise control over struct field order and types, use explicit `CAST`: + +```sql +-- Guarantee specific field order and types +SELECT CAST({b: 3, a: 4} AS STRUCT(a INT, b INT)); +-- Result: {"a": 4, "b": 3} +``` + +## Best Practices + +### 1. Be Explicit with Schema Definitions + +When joining or combining structs, define target schemas explicitly: + +```sql +-- Good: explicit schema definition +SELECT CAST(data AS STRUCT(id INT, name VARCHAR, active BOOLEAN)) +FROM external_source; +``` + +### 2. Use Named Struct Constructors + +Prefer named struct constructors for clarity: + +```sql +-- Good: field names are explicit +SELECT named_struct('id', 1, 'name', 'Alice', 'active', true); + +-- Or using struct literal syntax +SELECT {id: 1, name: 'Alice', active: true}; +``` + +### 3. 
Test Field Mappings + +Always verify that field mappings work as expected: + +```sql +-- Use arrow_typeof to verify unified schema +SELECT arrow_typeof([{x: 1, y: 2}, {y: 3, x: 4}]); +-- Result: List(Struct("x": Int32, "y": Int32)) +``` + +### 4. Handle Partial Field Overlap Explicitly + +When combining structs with partial field overlap, use explicit CAST: + +```sql +-- Instead of relying on implicit coercion +SELECT [ + CAST(left_struct AS STRUCT(x INT, y INT, z INT)), + CAST(right_struct AS STRUCT(x INT, y INT, z INT)) +]; +``` + +### 5. Document Struct Schemas + +In complex queries, document the expected struct schemas: + +```sql +-- Expected schema: {customer_id: INT, name: VARCHAR, age: INT} +SELECT { + customer_id: c.id, + name: c.name, + age: c.age +} as customer_info +FROM customers c; +``` + +## Error Messages and Troubleshooting + +### "Cannot coerce struct with different field counts" + +**Cause:** Trying to combine structs with different numbers of fields. + +**Solution:** + +```sql +-- Use explicit CAST to handle missing fields +SELECT [ + CAST(struct1 AS STRUCT(a INT, b INT, c INT)), + CAST(struct2 AS STRUCT(a INT, b INT, c INT)) +]; +``` + +### "Field X not found in struct" + +**Cause:** Referencing a field name that doesn't exist in the struct. + +**Solution:** + +```sql +-- Verify field names match exactly (case-sensitive) +SELECT s['field_name'] FROM my_table; -- Use bracket notation for access +-- Or use get_field function +SELECT get_field(s, 'field_name') FROM my_table; +``` + +### Unexpected NULL values after coercion + +**Cause:** Struct coercion added NULL for missing fields. 
+ +**Solution:** Check that all structs have the required fields, or explicitly handle NULLs: + +```sql +SELECT COALESCE(s['field'], default_value) FROM my_table; +``` + +## Related Functions + +- `arrow_typeof()` - Returns the Arrow type of an expression +- `struct()` / `named_struct()` - Creates struct values +- `get_field()` - Extracts field values from structs +- `CAST()` - Explicitly casts structs to specific schemas diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index 2228010b28dd1..bb8fdad5a0f89 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -129,7 +129,7 @@ impl BatchBuilder { } } - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn append_row( &mut self, rng: &mut StdRng,