From b1d4559fea7cd4ebadc733f746057c95189715a0 Mon Sep 17 00:00:00 2001
From: usamoi
Date: Thu, 2 Jan 2025 16:01:27 +0800
Subject: [PATCH] refactor: move pgvecto.rs base to this repo

Signed-off-by: usamoi
---
 .github/workflows/psql.yml                  |    6 +-
 .github/workflows/release.yml               |    5 +-
 Cargo.lock                                  |  169 +-
 Cargo.toml                                  |   54 +-
 build.rs                                    |    4 +
 crates/algorithm/Cargo.toml                 |    7 +
 crates/algorithm/src/lib.rs                 |   41 +
 crates/always_equal/Cargo.toml              |    7 +
 crates/always_equal/src/lib.rs              |   30 +
 crates/distance/Cargo.toml                  |    7 +
 crates/distance/src/lib.rs                  |   79 +
 crates/rabitq/Cargo.toml                    |   11 +
 crates/rabitq/src/binary.rs                 |  123 +
 crates/rabitq/src/block.rs                  |  172 ++
 crates/rabitq/src/lib.rs                    |    5 +
 .../rabitq/src/utils.rs                     |    0
 crates/random_orthogonal_matrix/Cargo.toml  |   15 +
 crates/random_orthogonal_matrix/src/lib.rs  |   55 +
 crates/simd/Cargo.toml                      |   18 +
 crates/simd/build.rs                        |   12 +
 crates/simd/cshim.c                         |  223 ++
 crates/simd/src/aligned.rs                  |    4 +
 crates/simd/src/bit.rs                      |  399 +++
 crates/simd/src/emulate.rs                  |  208 ++
 crates/simd/src/f16.rs                      | 1185 +++++++++
 crates/simd/src/f32.rs                      | 2277 +++++++++++++++++
 crates/simd/src/fast_scan/mod.rs            |  496 ++++
 crates/simd/src/lib.rs                      |  107 +
 crates/simd/src/packed_u4.rs                |   18 +
 crates/simd/src/quantize.rs                 |  291 +++
 crates/simd/src/u8.rs                       |  343 +++
 crates/simd_macros/Cargo.toml               |   18 +
 crates/simd_macros/src/lib.rs               |  206 ++
 crates/simd_macros/src/target.rs            |   83 +
 crates/vector/Cargo.toml                    |   13 +
 crates/vector/src/bvect.rs                  |  274 ++
 crates/vector/src/lib.rs                    |   48 +
 {src/types => crates/vector/src}/scalar8.rs |   13 +-
 crates/vector/src/svect.rs                  |  445 ++++
 crates/vector/src/vect.rs                   |  216 ++
 rustfmt.toml                                |    1 +
 scripts/README.md                           |    9 +-
 src/bin/pgrx_embed.rs                       |    1 -
 src/datatype/binary_scalar8.rs              |    4 +-
 src/datatype/functions_scalar8.rs           |   12 +-
 src/datatype/memory_pgvector_halfvec.rs     |    3 +-
 src/datatype/memory_pgvector_vector.rs      |    3 +-
 src/datatype/memory_scalar8.rs              |    4 +-
 src/datatype/operators_pgvector_halfvec.rs  |    5 +-
 src/datatype/operators_pgvector_vector.rs   |    5 +-
 src/datatype/operators_scalar8.rs           |    4 +-
 src/datatype/text_scalar8.rs                |    2 +-
 src/datatype/typmod.rs                      |    2 -
 src/lib.rs                                  |    8 +-
 src/postgres.rs                             |  158 +-
 src/projection.rs                           |   72 +-
 src/types/mod.rs                            |    1 -
 src/utils/k_means.rs                        |  194 +-
 src/utils/mod.rs                            |    1 -
 src/utils/parallelism.rs                    |    5 +-
 src/vchordrq/algorithm/build.rs             |   15 +-
 src/vchordrq/algorithm/insert.rs            |   23 +-
 src/vchordrq/algorithm/mod.rs               |   59 -
 src/vchordrq/algorithm/prewarm.rs           |    2 +-
 src/vchordrq/algorithm/rabitq.rs            |  131 +-
 src/vchordrq/algorithm/scan.rs              |   22 +-
 src/vchordrq/algorithm/tuples.rs            |   59 +-
 src/vchordrq/algorithm/vacuum.rs            |   12 +-
 src/vchordrq/algorithm/vectors.rs           |    9 +-
 src/vchordrq/gucs/executing.rs              |    6 +-
 src/vchordrq/index/am.rs                    |   36 +-
 src/vchordrq/index/am_options.rs            |   10 +-
 src/vchordrq/index/am_scan.rs               |   12 +-
 src/vchordrq/index/functions.rs             |    6 +-
 src/vchordrq/index/utils.rs                 |   10 +-
 src/vchordrq/types.rs                       |   10 +-
 src/vchordrqfscan/algorithm/build.rs        |   33 +-
 src/vchordrqfscan/algorithm/insert.rs       |   27 +-
 src/vchordrqfscan/algorithm/prewarm.rs      |    4 +-
 src/vchordrqfscan/algorithm/rabitq.rs       |  171 +-
 src/vchordrqfscan/algorithm/scan.rs         |   21 +-
 src/vchordrqfscan/algorithm/tuples.rs       |    8 +-
 src/vchordrqfscan/algorithm/vacuum.rs       |   18 +-
 src/vchordrqfscan/gucs/executing.rs         |    6 +-
 src/vchordrqfscan/index/am.rs               |   26 +-
 src/vchordrqfscan/index/am_options.rs       |    8 +-
 src/vchordrqfscan/index/am_scan.rs          |   10 +-
 src/vchordrqfscan/index/functions.rs        |    4 +-
 src/vchordrqfscan/index/utils.rs            |   10 +-
 src/vchordrqfscan/types.rs                  |   19 +-
 tests/logic/reindex.slt                     |    2 +-
 tools/package.sh                            |    5 +-
 tools/schema.sh                             |    8 -
 93 files changed, 7995 insertions(+), 988 deletions(-)
 create mode 100644 build.rs
 create mode 100644 crates/algorithm/Cargo.toml
 create mode 100644 crates/algorithm/src/lib.rs
 create mode 100644 crates/always_equal/Cargo.toml
 create mode 100644 crates/always_equal/src/lib.rs
 create mode 100644 crates/distance/Cargo.toml
 create mode 100644 crates/distance/src/lib.rs
 create mode 100644 crates/rabitq/Cargo.toml
 create mode 100644 crates/rabitq/src/binary.rs
 create mode 100644 crates/rabitq/src/block.rs
 create mode 100644 crates/rabitq/src/lib.rs
 rename src/utils/infinite_byte_chunks.rs => crates/rabitq/src/utils.rs (100%)
 create mode 100644 crates/random_orthogonal_matrix/Cargo.toml
 create mode 100644 crates/random_orthogonal_matrix/src/lib.rs
 create mode 100644 crates/simd/Cargo.toml
 create mode 100644 crates/simd/build.rs
 create mode 100644 crates/simd/cshim.c
 create mode 100644 crates/simd/src/aligned.rs
 create mode 100644 crates/simd/src/bit.rs
 create mode 100644 crates/simd/src/emulate.rs
 create mode 100644 crates/simd/src/f16.rs
 create mode 100644 crates/simd/src/f32.rs
 create mode 100644 crates/simd/src/fast_scan/mod.rs
 create mode 100644 crates/simd/src/lib.rs
 create mode 100644 crates/simd/src/packed_u4.rs
 create mode 100644 crates/simd/src/quantize.rs
 create mode 100644 crates/simd/src/u8.rs
 create mode 100644 crates/simd_macros/Cargo.toml
 create mode 100644 crates/simd_macros/src/lib.rs
 create mode 100644 crates/simd_macros/src/target.rs
 create mode 100644 crates/vector/Cargo.toml
 create mode 100644 crates/vector/src/bvect.rs
 create mode 100644 crates/vector/src/lib.rs
 rename {src/types => crates/vector/src}/scalar8.rs (93%)
 create mode 100644 crates/vector/src/svect.rs
 create mode 100644 crates/vector/src/vect.rs
 create mode 100644 rustfmt.toml
 delete mode 100644 src/types/mod.rs

diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml
index 17f4bac..a631189 100644
--- a/.github/workflows/psql.yml
+++ b/.github/workflows/psql.yml
@@ -34,7 +34,6 @@ env:
   CARGO_TERM_COLOR: always
   RUST_BACKTRACE: 1
   RUSTFLAGS: "-Dwarnings"
-  CARGO_PROFILE_OPT_BUILD_OVERRIDE_DEBUG: true
 
 jobs:
   test:
@@ -71,10 +70,9 @@ jobs:
         env:
           SEMVER: "0.0.0"
           VERSION: ${{ matrix.version }}
-          PROFILE: "opt"
         run: |
-          docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo build --lib --features pg${{ matrix.version }} --profile $PROFILE
-          docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE ./tools/schema.sh --features pg${{ matrix.version }} --profile $PROFILE
+          docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo build --lib --features pg${{ matrix.version }} --release
+          docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE ./tools/schema.sh --features pg${{ matrix.version }} --release
           ./tools/package.sh
           docker build -t vchord:pg${{ matrix.version }} --build-arg PG_VERSION=${{ matrix.version }} -f ./docker/Dockerfile .
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d7b1758..9608824 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -66,11 +66,10 @@ jobs: - name: Build env: VERSION: ${{ matrix.version }} - PROFILE: "release" GH_TOKEN: ${{ github.token }} run: | - docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo build --lib --features pg$VERSION --profile $PROFILE - docker run --rm -v .:/workspace $CACHE_ENVS -e SEMVER=${SEMVER} $PGRX_IMAGE ./tools/schema.sh --features pg$VERSION --profile $PROFILE + docker run --rm -v .:/workspace $CACHE_ENVS $PGRX_IMAGE cargo build --lib --features pg$VERSION --release + docker run --rm -v .:/workspace $CACHE_ENVS -e SEMVER=${SEMVER} $PGRX_IMAGE ./tools/schema.sh --features pg$VERSION --release ./tools/package.sh ls ./build gh release upload --clobber $SEMVER ./build/vchord-pg${VERSION}_${SEMVER}_${PLATFORM}.deb diff --git a/Cargo.lock b/Cargo.lock index 9439996..5d17e1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,6 +22,14 @@ dependencies = [ "memchr", ] +[[package]] +name = "algorithm" +version = "0.0.0" + +[[package]] +name = "always_equal" +version = "0.0.0" + [[package]] name = "annotate-snippets" version = "0.9.2" @@ -63,32 +71,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" -[[package]] -name = "base" -version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=9d87afd75ca3dd6819da2a0a38d9fefdfb5b1c74#9d87afd75ca3dd6819da2a0a38d9fefdfb5b1c74" -dependencies = [ - "base_macros", - "cc", - "half 2.4.1", - "libc", - "rand", - "serde", - "thiserror 2.0.9", - "toml", - "validator", -] - -[[package]] -name = "base_macros" -version = "0.0.0" -source = "git+https://github.com/tensorchord/pgvecto.rs.git?rev=9d87afd75ca3dd6819da2a0a38d9fefdfb5b1c74#9d87afd75ca3dd6819da2a0a38d9fefdfb5b1c74" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.91", -] - [[package]] name = "bindgen" version = "0.70.1" @@ -105,7 +87,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -178,9 +160,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.5" +version = "1.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31a0499c1dc64f458ad13872de75c0eb7e3fdb0e67964610c914b034fc5956e" +checksum = "8d6dbb628b8f8555f86d0323c2eb39e3ec81901f4b83e091db8a6a76d316a333" dependencies = [ "shlex", ] @@ -282,7 +264,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -293,7 +275,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -304,9 +286,13 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] +[[package]] +name = "distance" +version = "0.0.0" + [[package]] name = "either" version = "1.13.0" @@ -330,7 +316,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -389,9 +375,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = 
"a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" @@ -402,12 +388,10 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" [[package]] name = "half" version = "2.4.1" -source = "git+https://github.com/tensorchord/half-rs.git#5b7fedc636c0eb1624763a40840c5cbf54cffd02" +source = "git+https://github.com/tensorchord/half-rs.git#2d8a66092bee436aebb26ea7ac47d11150cda31d" dependencies = [ "cfg-if", "crunchy", - "rand", - "rand_distr", "rkyv", "serde", ] @@ -576,7 +560,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -734,7 +718,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -893,7 +877,7 @@ dependencies = [ "proc-macro2", "quote", "shlex", - "syn 2.0.91", + "syn 2.0.94", "walkdir", ] @@ -916,7 +900,7 @@ dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -963,7 +947,7 @@ dependencies = [ "petgraph", "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", "thiserror 1.0.69", "unescape", ] @@ -996,7 +980,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1030,13 +1014,21 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] +[[package]] +name = "rabitq" +version = "0.0.0" +dependencies = [ + "distance", + "simd", +] + [[package]] name = "radium" version = "0.7.0" @@ -1083,6 +1075,16 @@ dependencies = [ "rand", ] +[[package]] +name = "random_orthogonal_matrix" +version = "0.0.0" +dependencies = [ + "nalgebra", + "rand", + "rand_chacha", + "rand_distr", +] + [[package]] name = "rawpointer" version = "0.2.1" @@ -1241,9 +1243,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] @@ -1260,13 +1262,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1309,6 +1311,26 @@ dependencies = [ "wide", ] +[[package]] +name = "simd" +version = "0.0.0" +dependencies = [ + "cc", + "half 2.4.1", + "rand", + "serde", + "simd_macros", +] + +[[package]] +name = "simd_macros" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.94", +] + [[package]] name = "simdutf8" version = "0.1.5" @@ -1371,9 +1393,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.91" +version = "2.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53cbcb5a243bd33b7858b1d7f4aca2153490815872d86d955d6ea29f743c035" 
+checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" dependencies = [ "proc-macro2", "quote", @@ -1388,7 +1410,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1423,7 +1445,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1434,7 +1456,7 @@ checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1591,28 +1613,41 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] name = "vchord" version = "0.0.0" dependencies = [ - "base", + "algorithm", + "always_equal", + "distance", "half 2.4.1", "log", - "nalgebra", "paste", "pgrx", "pgrx-catalog", + "rabitq", "rand", - "rand_chacha", - "rand_distr", + "random_orthogonal_matrix", "rayon", "rkyv", "serde", + "simd", "toml", "validator", + "vector", +] + +[[package]] +name = "vector" +version = "0.0.0" +dependencies = [ + "distance", + "half 2.4.1", + "serde", + "simd", ] [[package]] @@ -1762,9 +1797,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "e6f5bb5257f2407a5425c6e749bfd9692192a73e70a6060516ac04f889087d68" dependencies = [ "memchr", ] @@ -1819,7 +1854,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", "synstructure", ] @@ -1841,7 +1876,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] [[package]] @@ -1861,7 +1896,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", "synstructure", ] @@ -1884,5 +1919,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.91", + "syn 2.0.94", ] diff --git a/Cargo.toml b/Cargo.toml index 2697ad1..13272c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vchord" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [lib] name = "vchord" @@ -20,24 +20,25 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"] pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"] [dependencies] -base = { git = "https://github.com/tensorchord/pgvecto.rs.git", rev = "9d87afd75ca3dd6819da2a0a38d9fefdfb5b1c74" } - -# lock algebra version forever so that the QR decomposition never changes for same input -nalgebra = "=0.33.0" +algorithm = { path = "./crates/algorithm" } +always_equal = { path = "./crates/always_equal" } +distance = { path = "./crates/distance" } +rabitq = { path = "./crates/rabitq" } +random_orthogonal_matrix = { path = "./crates/random_orthogonal_matrix" } +simd = { path = "./crates/simd" } +vector = { path = "./crates/vector" } # lock rkyv version forever so that data is always compatible rkyv = { version = "=0.7.45", features = ["validation"] } -half = { version = "2.4.1", features = ["rkyv"] } 
+half.workspace = true log = "0.4.22" paste = "1" pgrx = { version = "=0.12.9", default-features = false, features = ["cshim"] } pgrx-catalog = "0.1.0" -rand = "0.8.5" -rand_chacha = "0.3.1" -rand_distr = "0.4.3" +rand.workspace = true rayon = "1.10.0" -serde = "1" +serde.workspace = true toml = "0.8.19" validator = { version = "0.19.0", features = ["derive"] } @@ -45,21 +46,30 @@ validator = { version = "0.19.0", features = ["derive"] } half = { git = "https://github.com/tensorchord/half-rs.git" } [lints] -rust.fuzzy_provenance_casts = "deny" -rust.unexpected_cfgs = { level = "warn", check-cfg = [ - 'cfg(feature, values("pg12"))', - 'cfg(pgrx_embed)', -] } +workspace = true + +[workspace] +resolver = "2" +members = ["crates/*"] + +[workspace.package] +version = "0.0.0" +edition = "2021" + +[workspace.dependencies] +half = { version = "2.4.1", features = ["rkyv", "serde"] } +rand = "0.8.5" +serde = "1" + +[workspace.lints] +clippy.identity_op = "allow" +clippy.int_plus_one = "allow" +clippy.needless_range_loop = "allow" +clippy.nonminimal_bool = "allow" rust.unsafe_op_in_unsafe_fn = "deny" rust.unused_lifetimes = "warn" rust.unused_qualifications = "warn" -[profile.opt] -debug-assertions = false -inherits = "dev" -opt-level = 3 -overflow-checks = false - [profile.release] codegen-units = 1 debug = true diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..1fadc37 --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +fn main() { + println!(r#"cargo::rustc-check-cfg=cfg(pgrx_embed)"#); + println!(r#"cargo::rustc-check-cfg=cfg(feature, values("pg12"))"#); +} diff --git a/crates/algorithm/Cargo.toml b/crates/algorithm/Cargo.toml new file mode 100644 index 0000000..56da6ce --- /dev/null +++ b/crates/algorithm/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "algorithm" +version.workspace = true +edition.workspace = true + +[lints] +workspace = true diff --git a/crates/algorithm/src/lib.rs b/crates/algorithm/src/lib.rs new file mode 100644 index 0000000..8bee016 --- /dev/null +++ b/crates/algorithm/src/lib.rs @@ -0,0 +1,41 @@ +use std::ops::{Deref, DerefMut}; + +#[repr(C, align(8))] +pub struct Opaque { + pub next: u32, + pub skip: u32, +} + +#[allow(clippy::len_without_is_empty)] +pub trait Page: Sized { + fn get_opaque(&self) -> &Opaque; + fn get_opaque_mut(&mut self) -> &mut Opaque; + fn len(&self) -> u16; + fn get(&self, i: u16) -> Option<&[u8]>; + fn get_mut(&mut self, i: u16) -> Option<&mut [u8]>; + fn alloc(&mut self, data: &[u8]) -> Option; + fn free(&mut self, i: u16); + fn reconstruct(&mut self, removes: &[u16]); + fn freespace(&self) -> u16; +} + +pub trait PageGuard { + fn id(&self) -> u32; +} + +pub trait RelationRead { + type Page: Page; + type ReadGuard<'a>: PageGuard + Deref + where + Self: 'a; + fn read(&self, id: u32) -> Self::ReadGuard<'_>; +} + +pub trait RelationWrite: RelationRead { + type WriteGuard<'a>: PageGuard + DerefMut + where + Self: 'a; + fn write(&self, id: u32, tracking_freespace: bool) -> Self::WriteGuard<'_>; + fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_>; + fn search(&self, freespace: usize) -> Option>; +} diff --git a/crates/always_equal/Cargo.toml b/crates/always_equal/Cargo.toml new file mode 100644 index 0000000..923fa8e --- /dev/null +++ b/crates/always_equal/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "always_equal" +version.workspace = true +edition.workspace = true + +[lints] +workspace = true diff --git a/crates/always_equal/src/lib.rs b/crates/always_equal/src/lib.rs new file mode 100644 index 0000000..8cb2cc1 --- 
/dev/null +++ b/crates/always_equal/src/lib.rs @@ -0,0 +1,30 @@ +use std::cmp::Ordering; +use std::hash::Hash; + +#[derive(Debug, Clone, Copy, Default)] +#[repr(transparent)] +pub struct AlwaysEqual(pub T); + +impl PartialEq for AlwaysEqual { + fn eq(&self, _: &Self) -> bool { + true + } +} + +impl Eq for AlwaysEqual {} + +impl PartialOrd for AlwaysEqual { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for AlwaysEqual { + fn cmp(&self, _: &Self) -> Ordering { + Ordering::Equal + } +} + +impl Hash for AlwaysEqual { + fn hash(&self, _: &mut H) {} +} diff --git a/crates/distance/Cargo.toml b/crates/distance/Cargo.toml new file mode 100644 index 0000000..6c57819 --- /dev/null +++ b/crates/distance/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "distance" +version.workspace = true +edition.workspace = true + +[lints] +workspace = true diff --git a/crates/distance/src/lib.rs b/crates/distance/src/lib.rs new file mode 100644 index 0000000..5f21aac --- /dev/null +++ b/crates/distance/src/lib.rs @@ -0,0 +1,79 @@ +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct Distance(i32); + +impl Distance { + pub const ZERO: Self = Distance::from_f32(0.0f32); + pub const INFINITY: Self = Distance::from_f32(f32::INFINITY); + pub const NEG_INFINITY: Self = Distance::from_f32(f32::NEG_INFINITY); + + #[inline(always)] + pub const fn from_f32(value: f32) -> Self { + let bits = value.to_bits() as i32; + let mask = ((bits >> 31) as u32) >> 1; + let res = bits ^ (mask as i32); + Self(res) + } + + #[inline(always)] + pub const fn to_f32(self) -> f32 { + let bits = self.0; + let mask = ((bits >> 31) as u32) >> 1; + let res = bits ^ (mask as i32); + f32::from_bits(res as u32) + } + + #[inline(always)] + pub const fn to_i32(self) -> i32 { + self.0 + } +} + +impl From for Distance { + #[inline(always)] + fn from(value: f32) -> Self { + Distance::from_f32(value) + } +} + +impl From for f32 { + #[inline(always)] + fn from(value: Distance) -> Self { + Distance::to_f32(value) + } +} + +#[test] +fn distance_conversions() { + assert_eq!(Distance::from(0.0f32), Distance::ZERO); + assert_eq!(Distance::from(f32::INFINITY), Distance::INFINITY); + assert_eq!(Distance::from(f32::NEG_INFINITY), Distance::NEG_INFINITY); + for i in -100..100 { + let val = (i as f32) * 0.1; + assert_eq!(f32::from(Distance::from(val)).to_bits(), val.to_bits()); + } + assert_eq!( + f32::from(Distance::from(0.0f32)).to_bits(), + 0.0f32.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-0.0f32)).to_bits(), + (-0.0f32).to_bits() + ); + assert_eq!( + f32::from(Distance::from(f32::NAN)).to_bits(), + f32::NAN.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-f32::NAN)).to_bits(), + (-f32::NAN).to_bits() + ); + assert_eq!( + f32::from(Distance::from(f32::INFINITY)).to_bits(), + f32::INFINITY.to_bits() + ); + assert_eq!( + f32::from(Distance::from(-f32::INFINITY)).to_bits(), + (-f32::INFINITY).to_bits() + ); +} diff --git a/crates/rabitq/Cargo.toml b/crates/rabitq/Cargo.toml new file mode 100644 index 0000000..5ca300c --- /dev/null +++ b/crates/rabitq/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "rabitq" +version.workspace = true +edition.workspace = true + +[dependencies] +distance = { path = "../distance" } +simd = { path = "../simd" } + +[lints] +workspace = true diff --git a/crates/rabitq/src/binary.rs b/crates/rabitq/src/binary.rs new file mode 100644 index 0000000..8916ffd --- /dev/null +++ b/crates/rabitq/src/binary.rs @@ -0,0 +1,123 @@ 
+use distance::Distance; +use simd::Floating; + +#[derive(Debug, Clone)] +pub struct Code { + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, +} + +impl Code { + pub fn t(&self) -> Vec { + use crate::utils::InfiniteByteChunks; + let mut result = Vec::new(); + for x in InfiniteByteChunks::<_, 64>::new(self.signs.iter().copied()) + .take(self.signs.len().div_ceil(64)) + { + let mut r = 0_u64; + for i in 0..64 { + r |= (x[i] as u64) << i; + } + result.push(r); + } + result + } +} + +pub fn code(dims: u32, vector: &[f32]) -> Code { + let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); + let sum_of_x_2 = f32::reduce_sum_of_x2(vector); + let dis_u = sum_of_x_2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut signs = Vec::new(); + for i in 0..dims { + signs.push(vector[i as usize].is_sign_positive() as u8); + } + Code { + dis_u_2: sum_of_x_2, + factor_ppc, + factor_ip, + factor_err, + signs, + } +} + +pub type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); + +pub fn preprocess(vector: &[f32]) -> Lut { + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); + let qvector_sum = if vector.len() <= 4369 { + simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + simd::u8::reduce_sum_of_x(&qvector) as f32 + }; + (dis_v_2, b, k, qvector_sum, binarize(&qvector)) +} + +pub fn process_lowerbound_l2( + _: u32, + lut: &Lut, + (dis_u_2, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), + epsilon: f32, +) -> Distance { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = + dis_u_2 + dis_v_2 + b * factor_ppc + ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + let err = factor_err * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) +} + +pub fn process_lowerbound_dot( + _: u32, + lut: &Lut, + (_, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), + epsilon: f32, +) -> Distance { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let value = asymmetric_binary_dot_product(t, s) as u16; + let rough = 0.5 * b * factor_ppc + 0.5 * ((2.0 * value as f32) - qvector_sum) * factor_ip * k; + let err = 0.5 * factor_err * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) +} + +fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { + let n = vector.len(); + let mut t0 = vec![0u64; n.div_ceil(64)]; + let mut t1 = vec![0u64; n.div_ceil(64)]; + let mut t2 = vec![0u64; n.div_ceil(64)]; + let mut t3 = vec![0u64; n.div_ceil(64)]; + for i in 0..n { + t0[i / 64] |= (((vector[i] >> 0) & 1) as u64) << (i % 64); + t1[i / 64] |= (((vector[i] >> 1) & 1) as u64) << (i % 64); + t2[i / 64] |= (((vector[i] >> 2) & 1) as u64) << (i % 64); + t3[i / 64] |= (((vector[i] >> 3) & 1) as u64) << (i % 64); + } + (t0, t1, t2, t3) +} + +fn asymmetric_binary_dot_product(x: &[u64], y: &(Vec, Vec, Vec, Vec)) -> u32 { + let t0 = simd::bit::sum_of_and(x, &y.0); + let t1 = simd::bit::sum_of_and(x, &y.1); + let t2 = 
simd::bit::sum_of_and(x, &y.2); + let t3 = simd::bit::sum_of_and(x, &y.3); + (t0 << 0) + (t1 << 1) + (t2 << 2) + (t3 << 3) +} diff --git a/crates/rabitq/src/block.rs b/crates/rabitq/src/block.rs new file mode 100644 index 0000000..0a8dab7 --- /dev/null +++ b/crates/rabitq/src/block.rs @@ -0,0 +1,172 @@ +use distance::Distance; +use simd::Floating; + +#[derive(Debug, Clone)] +pub struct Code { + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, +} + +pub fn code(dims: u32, vector: &[f32]) -> Code { + let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); + let sum_of_x_2 = f32::reduce_sum_of_x2(vector); + let dis_u = sum_of_x_2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut signs = Vec::new(); + for i in 0..dims { + signs.push(vector[i as usize].is_sign_positive() as u8); + } + Code { + dis_u_2: sum_of_x_2, + factor_ppc, + factor_ip, + factor_err, + signs, + } +} + +pub fn dummy_code(dims: u32) -> Code { + Code { + dis_u_2: 0.0, + factor_ppc: 0.0, + factor_ip: 0.0, + factor_err: 0.0, + signs: vec![0; dims as _], + } +} + +pub struct PackedCodes { + pub dis_u_2: [f32; 32], + pub factor_ppc: [f32; 32], + pub factor_ip: [f32; 32], + pub factor_err: [f32; 32], + pub t: Vec, +} + +pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes { + use crate::utils::InfiniteByteChunks; + PackedCodes { + dis_u_2: std::array::from_fn(|i| codes[i].dis_u_2), + factor_ppc: std::array::from_fn(|i| codes[i].factor_ppc), + factor_ip: std::array::from_fn(|i| codes[i].factor_ip), + factor_err: std::array::from_fn(|i| codes[i].factor_err), + t: { + let signs = codes.map(|code| { + InfiniteByteChunks::new(code.signs.into_iter()) + .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) + .take(dims.div_ceil(4) as usize) + .collect::>() + }); + simd::fast_scan::pack(dims.div_ceil(4), signs).collect() + }, + } +} + +pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); + let qvector_sum = if vector.len() <= 4369 { + simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + simd::u8::reduce_sum_of_x(&qvector) as f32 + }; + (dis_v_2, b, k, qvector_sum, compress(qvector)) +} + +pub fn fscan_process_lowerbound_l2( + dims: u32, + lut: &(f32, f32, f32, f32, Vec), + (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + epsilon: f32, +) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) +} + +pub fn fscan_process_lowerbound_dot( + dims: u32, + lut: &(f32, f32, f32, f32, Vec), + (_, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], 
+ &[f32; 32], + &[u8], + ), + epsilon: f32, +) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = + 0.5 * b * factor_ppc[i] + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = 0.5 * factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) +} + +fn compress(mut qvector: Vec) -> Vec { + let dims = qvector.len() as u32; + let width = dims.div_ceil(4); + qvector.resize(qvector.len().next_multiple_of(4), 0); + let mut t = vec![0u8; width as usize * 16]; + for i in 0..width as usize { + unsafe { + // this hint is used to skip bound checks + std::hint::assert_unchecked(4 * i + 3 < qvector.len()); + std::hint::assert_unchecked(16 * i + 15 < t.len()); + } + let t0 = qvector[4 * i + 0]; + let t1 = qvector[4 * i + 1]; + let t2 = qvector[4 * i + 2]; + let t3 = qvector[4 * i + 3]; + t[16 * i + 0b0000] = 0; + t[16 * i + 0b0001] = t0; + t[16 * i + 0b0010] = t1; + t[16 * i + 0b0011] = t1 + t0; + t[16 * i + 0b0100] = t2; + t[16 * i + 0b0101] = t2 + t0; + t[16 * i + 0b0110] = t2 + t1; + t[16 * i + 0b0111] = t2 + t1 + t0; + t[16 * i + 0b1000] = t3; + t[16 * i + 0b1001] = t3 + t0; + t[16 * i + 0b1010] = t3 + t1; + t[16 * i + 0b1011] = t3 + t1 + t0; + t[16 * i + 0b1100] = t3 + t2; + t[16 * i + 0b1101] = t3 + t2 + t0; + t[16 * i + 0b1110] = t3 + t2 + t1; + t[16 * i + 0b1111] = t3 + t2 + t1 + t0; + } + t +} diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs new file mode 100644 index 0000000..0796b7a --- /dev/null +++ b/crates/rabitq/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(clippy::type_complexity)] + +pub mod binary; +pub mod block; +mod utils; diff --git a/src/utils/infinite_byte_chunks.rs b/crates/rabitq/src/utils.rs similarity index 100% rename from src/utils/infinite_byte_chunks.rs rename to crates/rabitq/src/utils.rs diff --git a/crates/random_orthogonal_matrix/Cargo.toml b/crates/random_orthogonal_matrix/Cargo.toml new file mode 100644 index 0000000..16c3a39 --- /dev/null +++ b/crates/random_orthogonal_matrix/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "random_orthogonal_matrix" +version.workspace = true +edition.workspace = true + +[dependencies] +# lock algebra version forever so that the QR decomposition never changes for same input +nalgebra = "=0.33.0" + +rand.workspace = true +rand_chacha = "0.3.1" +rand_distr = "0.4.3" + +[lints] +workspace = true diff --git a/crates/random_orthogonal_matrix/src/lib.rs b/crates/random_orthogonal_matrix/src/lib.rs new file mode 100644 index 0000000..e271150 --- /dev/null +++ b/crates/random_orthogonal_matrix/src/lib.rs @@ -0,0 +1,55 @@ +use nalgebra::DMatrix; + +#[ignore] +#[test] +fn check_full_rank_matrix() { + let parallelism = std::thread::available_parallelism().unwrap().get(); + std::thread::scope(|scope| { + let mut threads = vec![]; + for remainder in 0..parallelism { + threads.push(scope.spawn(move || { + for n in (0..=60000).filter(|x| x % parallelism == remainder) { + let matrix = random_full_rank_matrix(n); + assert!(matrix.is_invertible()); + } + })); + } + for thread in threads { + thread.join().unwrap(); + } + }); +} + +fn random_full_rank_matrix(n: usize) -> DMatrix { + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha12Rng; + use rand_distr::StandardNormal; + let mut rng = ChaCha12Rng::from_seed([7; 32]); + DMatrix::from_fn(n, n, |_, _| rng.sample(StandardNormal)) +} + +#[test] +fn check_random_orthogonal_matrix() { + assert_eq!(random_orthogonal_matrix(2), vec![ 
+ vec![-0.5424608, -0.8400813], + vec![0.8400813, -0.54246056] + ]); + assert_eq!(random_orthogonal_matrix(3), vec![ + vec![-0.5309615, -0.69094884, -0.49058124], + vec![0.8222731, -0.56002235, -0.10120347], + vec![0.20481002, 0.45712686, -0.86549866] + ]); +} + +pub fn random_orthogonal_matrix(n: usize) -> Vec> { + use nalgebra::QR; + let matrix = random_full_rank_matrix(n); + // QR decomposition is unique if the matrix is full rank + let qr = QR::new(matrix); + let q = qr.q(); + let mut projection = Vec::new(); + for row in q.row_iter() { + projection.push(row.iter().copied().collect::>()); + } + projection +} diff --git a/crates/simd/Cargo.toml b/crates/simd/Cargo.toml new file mode 100644 index 0000000..0905fb1 --- /dev/null +++ b/crates/simd/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "simd" +version.workspace = true +edition.workspace = true + +[dependencies] +half.workspace = true +serde.workspace = true +simd_macros = { path = "../simd_macros" } + +[dev-dependencies] +rand.workspace = true + +[build-dependencies] +cc = "1.2.6" + +[lints] +workspace = true diff --git a/crates/simd/build.rs b/crates/simd/build.rs new file mode 100644 index 0000000..22ebf93 --- /dev/null +++ b/crates/simd/build.rs @@ -0,0 +1,12 @@ +fn main() { + println!("cargo::rerun-if-changed=cshim.c"); + cc::Build::new() + .compiler("clang") + .file("cshim.c") + .opt_level(3) + .flag("-fassociative-math") + .flag("-ffp-contract=fast") + .flag("-freciprocal-math") + .flag("-fno-signed-zeros") + .compile("base_cshim"); +} diff --git a/crates/simd/cshim.c b/crates/simd/cshim.c new file mode 100644 index 0000000..1374bbf --- /dev/null +++ b/crates/simd/cshim.c @@ -0,0 +1,223 @@ +#if !(__clang_major__ >= 16) +#error "clang version must be >= 16" +#endif + +#include +#include + +#ifdef __aarch64__ + +#include +#include + +__attribute__((target("v8.3a,fp16"))) float +fp16_reduce_sum_of_xy_v8_3a_fp16_unroll(__fp16 *__restrict a, + __fp16 *__restrict b, size_t n) { + float16x8_t xy_0 = vdupq_n_f16(0.0); + float16x8_t xy_1 = vdupq_n_f16(0.0); + float16x8_t xy_2 = vdupq_n_f16(0.0); + float16x8_t xy_3 = vdupq_n_f16(0.0); + while (n >= 32) { + float16x8_t x_0 = vld1q_f16(a + 0); + float16x8_t x_1 = vld1q_f16(a + 8); + float16x8_t x_2 = vld1q_f16(a + 16); + float16x8_t x_3 = vld1q_f16(a + 24); + float16x8_t y_0 = vld1q_f16(b + 0); + float16x8_t y_1 = vld1q_f16(b + 8); + float16x8_t y_2 = vld1q_f16(b + 16); + float16x8_t y_3 = vld1q_f16(b + 24); + a += 32; + b += 32; + n -= 32; + xy_0 = vfmaq_f16(xy_0, x_0, y_0); + xy_1 = vfmaq_f16(xy_1, x_1, y_1); + xy_2 = vfmaq_f16(xy_2, x_2, y_2); + xy_3 = vfmaq_f16(xy_3, x_3, y_3); + } + if (n > 0) { + __fp16 A[32] = {}; + __fp16 B[32] = {}; + for (size_t i = 0; i < n; i += 1) { + A[i] = a[i]; + B[i] = b[i]; + } + float16x8_t x_0 = vld1q_f16(A + 0); + float16x8_t x_1 = vld1q_f16(A + 8); + float16x8_t x_2 = vld1q_f16(A + 16); + float16x8_t x_3 = vld1q_f16(A + 24); + float16x8_t y_0 = vld1q_f16(B + 0); + float16x8_t y_1 = vld1q_f16(B + 8); + float16x8_t y_2 = vld1q_f16(B + 16); + float16x8_t y_3 = vld1q_f16(B + 24); + xy_0 = vfmaq_f16(xy_0, x_0, y_0); + xy_1 = vfmaq_f16(xy_1, x_1, y_1); + xy_2 = vfmaq_f16(xy_2, x_2, y_2); + xy_3 = vfmaq_f16(xy_3, x_3, y_3); + } + float16x8_t xy = vaddq_f16(vaddq_f16(xy_0, xy_1), vaddq_f16(xy_2, xy_3)); + return vgetq_lane_f16(xy, 0) + vgetq_lane_f16(xy, 1) + vgetq_lane_f16(xy, 2) + + vgetq_lane_f16(xy, 3) + vgetq_lane_f16(xy, 4) + vgetq_lane_f16(xy, 5) + + vgetq_lane_f16(xy, 6) + vgetq_lane_f16(xy, 7); +} + +__attribute__((target("v8.3a,sve"))) float 
+fp16_reduce_sum_of_xy_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, + size_t n) { + svfloat16_t xy = svdup_f16(0.0); + for (size_t i = 0; i < n; i += svcnth()) { + svbool_t mask = svwhilelt_b16(i, n); + svfloat16_t x = svld1_f16(mask, a + i); + svfloat16_t y = svld1_f16(mask, b + i); + xy = svmla_f16_x(mask, xy, x, y); + } + return svaddv_f16(svptrue_b16(), xy); +} + +__attribute__((target("v8.3a,fp16"))) float +fp16_reduce_sum_of_d2_v8_3a_fp16_unroll(__fp16 *__restrict a, + __fp16 *__restrict b, size_t n) { + float16x8_t d2_0 = vdupq_n_f16(0.0); + float16x8_t d2_1 = vdupq_n_f16(0.0); + float16x8_t d2_2 = vdupq_n_f16(0.0); + float16x8_t d2_3 = vdupq_n_f16(0.0); + while (n >= 32) { + float16x8_t x_0 = vld1q_f16(a + 0); + float16x8_t x_1 = vld1q_f16(a + 8); + float16x8_t x_2 = vld1q_f16(a + 16); + float16x8_t x_3 = vld1q_f16(a + 24); + float16x8_t y_0 = vld1q_f16(b + 0); + float16x8_t y_1 = vld1q_f16(b + 8); + float16x8_t y_2 = vld1q_f16(b + 16); + float16x8_t y_3 = vld1q_f16(b + 24); + a += 32; + b += 32; + n -= 32; + float16x8_t d_0 = vsubq_f16(x_0, y_0); + float16x8_t d_1 = vsubq_f16(x_1, y_1); + float16x8_t d_2 = vsubq_f16(x_2, y_2); + float16x8_t d_3 = vsubq_f16(x_3, y_3); + d2_0 = vfmaq_f16(d2_0, d_0, d_0); + d2_1 = vfmaq_f16(d2_1, d_1, d_1); + d2_2 = vfmaq_f16(d2_2, d_2, d_2); + d2_3 = vfmaq_f16(d2_3, d_3, d_3); + } + if (n > 0) { + __fp16 A[32] = {}; + __fp16 B[32] = {}; + for (size_t i = 0; i < n; i += 1) { + A[i] = a[i]; + B[i] = b[i]; + } + float16x8_t x_0 = vld1q_f16(A + 0); + float16x8_t x_1 = vld1q_f16(A + 8); + float16x8_t x_2 = vld1q_f16(A + 16); + float16x8_t x_3 = vld1q_f16(A + 24); + float16x8_t y_0 = vld1q_f16(B + 0); + float16x8_t y_1 = vld1q_f16(B + 8); + float16x8_t y_2 = vld1q_f16(B + 16); + float16x8_t y_3 = vld1q_f16(B + 24); + float16x8_t d_0 = vsubq_f16(x_0, y_0); + float16x8_t d_1 = vsubq_f16(x_1, y_1); + float16x8_t d_2 = vsubq_f16(x_2, y_2); + float16x8_t d_3 = vsubq_f16(x_3, y_3); + d2_0 = vfmaq_f16(d2_0, d_0, d_0); + d2_1 = vfmaq_f16(d2_1, d_1, d_1); + d2_2 = vfmaq_f16(d2_2, d_2, d_2); + d2_3 = vfmaq_f16(d2_3, d_3, d_3); + } + float16x8_t d2 = vaddq_f16(vaddq_f16(d2_0, d2_1), vaddq_f16(d2_2, d2_3)); + return vgetq_lane_f16(d2, 0) + vgetq_lane_f16(d2, 1) + vgetq_lane_f16(d2, 2) + + vgetq_lane_f16(d2, 3) + vgetq_lane_f16(d2, 4) + vgetq_lane_f16(d2, 5) + + vgetq_lane_f16(d2, 6) + vgetq_lane_f16(d2, 7); +} + +__attribute__((target("v8.3a,sve"))) float +fp16_reduce_sum_of_d2_v8_3a_sve(__fp16 *__restrict a, __fp16 *__restrict b, + size_t n) { + svfloat16_t d2 = svdup_f16(0.0); + for (size_t i = 0; i < n; i += svcnth()) { + svbool_t mask = svwhilelt_b16(i, n); + svfloat16_t x = svld1_f16(mask, a + i); + svfloat16_t y = svld1_f16(mask, b + i); + svfloat16_t d = svsub_f16_x(mask, x, y); + d2 = svmla_f16_x(mask, d2, d, d); + } + return svaddv_f16(svptrue_b16(), d2); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_x_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svadd_f32_x(mask, sum, x); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_abs_x_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svadd_f32_x(mask, sum, svabs_f32_x(mask, x)); + } + 
return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_x2_v8_3a_sve(float *__restrict this, size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + sum = svmla_f32_x(mask, sum, x, x); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) void +fp32_reduce_min_max_of_x_v8_3a_sve(float *__restrict this, size_t n, + float *out_min, float *out_max) { + svfloat32_t min = svdup_f32(1.0 / 0.0); + svfloat32_t max = svdup_f32(-1.0 / 0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, this + i); + min = svmin_f32_x(mask, min, x); + max = svmax_f32_x(mask, max, x); + } + *out_min = svminv_f32(svptrue_b32(), min); + *out_max = svmaxv_f32(svptrue_b32(), max); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_xy_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, + size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, lhs + i); + svfloat32_t y = svld1_f32(mask, rhs + i); + sum = svmla_f32_x(mask, sum, x, y); + } + return svaddv_f32(svptrue_b32(), sum); +} + +__attribute__((target("v8.3a,sve"))) float +fp32_reduce_sum_of_d2_v8_3a_sve(float *__restrict lhs, float *__restrict rhs, + size_t n) { + svfloat32_t sum = svdup_f32(0.0); + for (size_t i = 0; i < n; i += svcntw()) { + svbool_t mask = svwhilelt_b32(i, n); + svfloat32_t x = svld1_f32(mask, lhs + i); + svfloat32_t y = svld1_f32(mask, rhs + i); + svfloat32_t d = svsub_f32_x(mask, x, y); + sum = svmla_f32_x(mask, sum, d, d); + } + return svaddv_f32(svptrue_b32(), sum); +} + +#endif diff --git a/crates/simd/src/aligned.rs b/crates/simd/src/aligned.rs new file mode 100644 index 0000000..18db9d8 --- /dev/null +++ b/crates/simd/src/aligned.rs @@ -0,0 +1,4 @@ +#[allow(dead_code)] +#[derive(Debug, Clone, Copy)] +#[repr(C, align(16))] +pub struct Aligned16(pub T); diff --git a/crates/simd/src/bit.rs b/crates/simd/src/bit.rs new file mode 100644 index 0000000..dc417bd --- /dev/null +++ b/crates/simd/src/bit.rs @@ -0,0 +1,399 @@ +#[inline(always)] +pub fn sum_of_and(lhs: &[u64], rhs: &[u64]) -> u32 { + sum_of_and::sum_of_and(lhs, rhs) +} + +mod sum_of_and { + // FIXME: add manually-implemented SIMD version for AVX512 and AVX2 + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_and_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut and = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= 8 { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(8); + b = b.add(8); + n -= 8; + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + } + if n > 0 { + let mask = _bzhi_u32(0xff, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + } + _mm512_reduce_add_epi64(and) as u32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn sum_of_and_v4_avx512vpopcntdq_test() { + if 
!crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!("test {} ... skipped (v4:avx512vpopcntdq)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lhs = (0..126).map(|_| rand::random::()).collect::>(); + let rhs = (0..126).map(|_| rand::random::()).collect::>(); + let specialized = unsafe { sum_of_and_v4_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = sum_of_and_fallback(&lhs, &rhs); + assert_eq!(specialized, fallback); + } + } + + #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn sum_of_and(lhs: &[u64], rhs: &[u64]) -> u32 { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut and = 0; + for i in 0..n { + and += (lhs[i] & rhs[i]).count_ones(); + } + and + } +} + +#[inline(always)] +pub fn sum_of_or(lhs: &[u64], rhs: &[u64]) -> u32 { + sum_of_or::sum_of_or(lhs, rhs) +} + +mod sum_of_or { + // FIXME: add manually-implemented SIMD version for AVX512 and AVX2 + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut or = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= 8 { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(8); + b = b.add(8); + n -= 8; + or = _mm512_add_epi64(or, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); + } + if n > 0 { + let mask = _bzhi_u32(0xff, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + or = _mm512_add_epi64(or, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); + } + _mm512_reduce_add_epi64(or) as u32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn sum_of_or_v4_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lhs = (0..126).map(|_| rand::random::()).collect::>(); + let rhs = (0..126).map(|_| rand::random::()).collect::>(); + let specialized = unsafe { sum_of_or_v4_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = sum_of_or_fallback(&lhs, &rhs); + assert_eq!(specialized, fallback); + } + } + + #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn sum_of_or(lhs: &[u64], rhs: &[u64]) -> u32 { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut or = 0; + for i in 0..n { + or += (lhs[i] | rhs[i]).count_ones(); + } + or + } +} + +#[inline(always)] +pub fn sum_of_xor(lhs: &[u64], rhs: &[u64]) -> u32 { + sum_of_xor::sum_of_xor(lhs, rhs) +} + +mod sum_of_xor { + // FIXME: add manually-implemented SIMD version for AVX512 and AVX2 + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_xor_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> u32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut xor = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= 8 { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(8); + b = b.add(8); + n -= 8; + xor = _mm512_add_epi64(xor, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); + } + if n > 0 { + let mask = _bzhi_u32(0xff, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + xor = _mm512_add_epi64(xor, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); + } + _mm512_reduce_add_epi64(xor) as u32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn sum_of_xor_v4_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lhs = (0..126).map(|_| rand::random::()).collect::>(); + let rhs = (0..126).map(|_| rand::random::()).collect::>(); + let specialized = unsafe { sum_of_xor_v4_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = sum_of_xor_fallback(&lhs, &rhs); + assert_eq!(specialized, fallback); + } + } + + #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn sum_of_xor(lhs: &[u64], rhs: &[u64]) -> u32 { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut xor = 0; + for i in 0..n { + xor += (lhs[i] ^ rhs[i]).count_ones(); + } + xor + } +} + +#[inline(always)] +pub fn sum_of_and_or(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { + sum_of_and_or::sum_of_and_or(lhs, rhs) +} + +mod sum_of_and_or { + // FIXME: add manually-implemented SIMD version for AVX512 and AVX2 + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_and_or_v4_avx512vpopcntdq(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut and = _mm512_setzero_si512(); + let mut or = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= 8 { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(8); + b = b.add(8); + n -= 8; + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + or = _mm512_add_epi64(or, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); + } + if n > 0 { + let mask = _bzhi_u32(0xff, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + or = _mm512_add_epi64(or, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); + } + ( + _mm512_reduce_add_epi64(and) as u32, + _mm512_reduce_add_epi64(or) as u32, + ) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn sum_of_xor_v4_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lhs = (0..126).map(|_| rand::random::()).collect::>(); + let rhs = (0..126).map(|_| rand::random::()).collect::>(); + let specialized = unsafe { sum_of_and_or_v4_avx512vpopcntdq(&lhs, &rhs) }; + let fallback = sum_of_and_or_fallback(&lhs, &rhs); + assert_eq!(specialized, fallback); + } + } + + #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn sum_of_and_or(lhs: &[u64], rhs: &[u64]) -> (u32, u32) { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut and = 0; + let mut or = 0; + for i in 0..n { + and += (lhs[i] & rhs[i]).count_ones(); + or += (lhs[i] | rhs[i]).count_ones(); + } + (and, or) + } +} + +#[inline(always)] +pub fn sum_of_x(this: &[u64]) -> u32 { + sum_of_x::sum_of_x(this) +} + +mod sum_of_x { + // FIXME: add manually-implemented SIMD version for AVX512 and AVX2 + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512vpopcntdq")] + fn sum_of_x_v4_avx512vpopcntdq(this: &[u64]) -> u32 { + unsafe { + use std::arch::x86_64::*; + let mut and = _mm512_setzero_si512(); + let mut a = this.as_ptr(); + let mut n = this.len(); + while n >= 8 { + let x = _mm512_loadu_si512(a.cast()); + a = a.add(8); + n -= 8; + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(x)); + } + if n > 0 { + let mask = _bzhi_u32(0xff, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + and = _mm512_add_epi64(and, _mm512_popcnt_epi64(x)); + } + _mm512_reduce_add_epi64(and) as u32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn sum_of_x_v4_avx512vpopcntdq_test() { + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512vpopcntdq") { + println!("test {} ... 
skipped (v4:avx512vpopcntdq)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let this = (0..126).map(|_| rand::random::()).collect::>(); + let specialized = unsafe { sum_of_x_v4_avx512vpopcntdq(&this) }; + let fallback = sum_of_x_fallback(&this); + assert_eq!(specialized, fallback); + } + } + + #[crate::multiversion(@"v4:avx512vpopcntdq", "v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn sum_of_x(this: &[u64]) -> u32 { + let n = this.len(); + let mut and = 0; + for i in 0..n { + and += this[i].count_ones(); + } + and + } +} + +#[inline(always)] +pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { + vector_and::vector_and(lhs, rhs) +} + +mod vector_and { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_and(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] & rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +#[inline(always)] +pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { + vector_or::vector_or(lhs, rhs) +} + +mod vector_or { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_or(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] | rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +#[inline(always)] +pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { + vector_xor::vector_xor(lhs, rhs) +} + +mod vector_xor { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_xor(lhs: &[u64], rhs: &[u64]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] ^ rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} diff --git a/crates/simd/src/emulate.rs b/crates/simd/src/emulate.rs new file mode 100644 index 0000000..520b6a1 --- /dev/null +++ b/crates/simd/src/emulate.rs @@ -0,0 +1,208 @@ +// VP2INTERSECT emulation. +// Díez-Cañas, G. (2021). Faster-Than-Native Alternatives for x86 VP2INTERSECT +// Instructions. arXiv preprint arXiv:2112.06342. 
+#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v4")] +pub fn emulate_mm512_2intersect_epi32( + a: std::arch::x86_64::__m512i, + b: std::arch::x86_64::__m512i, +) -> (std::arch::x86_64::__mmask16, std::arch::x86_64::__mmask16) { + unsafe { + use std::arch::x86_64::*; + + let a1 = _mm512_alignr_epi32(a, a, 4); + let a2 = _mm512_alignr_epi32(a, a, 8); + let a3 = _mm512_alignr_epi32(a, a, 12); + let b1 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB); + let b2 = _mm512_shuffle_epi32(b, _MM_PERM_BADC); + let b3 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD); + let m00 = _mm512_cmpeq_epi32_mask(a, b); + let m01 = _mm512_cmpeq_epi32_mask(a, b1); + let m02 = _mm512_cmpeq_epi32_mask(a, b2); + let m03 = _mm512_cmpeq_epi32_mask(a, b3); + let m10 = _mm512_cmpeq_epi32_mask(a1, b); + let m11 = _mm512_cmpeq_epi32_mask(a1, b1); + let m12 = _mm512_cmpeq_epi32_mask(a1, b2); + let m13 = _mm512_cmpeq_epi32_mask(a1, b3); + let m20 = _mm512_cmpeq_epi32_mask(a2, b); + let m21 = _mm512_cmpeq_epi32_mask(a2, b1); + let m22 = _mm512_cmpeq_epi32_mask(a2, b2); + let m23 = _mm512_cmpeq_epi32_mask(a2, b3); + let m30 = _mm512_cmpeq_epi32_mask(a3, b); + let m31 = _mm512_cmpeq_epi32_mask(a3, b1); + let m32 = _mm512_cmpeq_epi32_mask(a3, b2); + let m33 = _mm512_cmpeq_epi32_mask(a3, b3); + + let m0 = m00 | m10 | m20 | m30; + let m1 = m01 | m11 | m21 | m31; + let m2 = m02 | m12 | m22 | m32; + let m3 = m03 | m13 | m23 | m33; + + let res_a = m00 + | m01 + | m02 + | m03 + | (m10 | m11 | m12 | m13).rotate_left(4) + | (m20 | m21 | m22 | m23).rotate_left(8) + | (m30 | m31 | m32 | m33).rotate_right(4); + + let res_b = m0 + | ((0x7777 & m1) << 1) + | ((m1 >> 3) & 0x1111) + | ((0x3333 & m2) << 2) + | ((m2 >> 2) & 0x3333) + | ((0x1111 & m3) << 3) + | ((m3 >> 1) & 0x7777); + (res_a, res_b) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f32 { + unsafe { + use std::arch::x86_64::*; + x = _mm256_add_ps(x, _mm256_permute2f128_ps(x, x, 1)); + x = _mm256_hadd_ps(x, x); + x = _mm256_hadd_ps(x, x); + _mm256_cvtss_f32(x) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_ps(mut x: std::arch::x86_64::__m128) -> f32 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_ps(x, x); + x = _mm_hadd_ps(x, x); + _mm_cvtss_f32(x) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v4")] +pub fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i16 { + unsafe { + use std::arch::x86_64::*; + _mm256_reduce_add_epi16(_mm512_castsi512_si256(x)) + + _mm256_reduce_add_epi16(_mm512_extracti32x8_epi32(x, 1)) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_add_epi16(mut x: std::arch::x86_64::__m256i) -> i16 { + unsafe { + use std::arch::x86_64::*; + x = _mm256_add_epi16(x, _mm256_permute2f128_si256(x, x, 1)); + x = _mm256_hadd_epi16(x, x); + x = _mm256_hadd_epi16(x, x); + let x = _mm256_cvtsi256_si32(x); + (x as i16) + ((x >> 16) as i16) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_epi16(mut x: std::arch::x86_64::__m128i) -> i16 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_epi16(x, x); + x = _mm_hadd_epi16(x, x); + let x = _mm_cvtsi128_si32(x); + (x as i16) + ((x >> 16) as i16) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v3")] +pub fn 
emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) -> i32 { + unsafe { + use std::arch::x86_64::*; + x = _mm256_add_epi32(x, _mm256_permute2f128_si256(x, x, 1)); + x = _mm256_hadd_epi32(x, x); + x = _mm256_hadd_epi32(x, x); + _mm256_cvtsi256_si32(x) + } +} + +#[expect(dead_code)] +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_add_epi32(mut x: std::arch::x86_64::__m128i) -> i32 { + unsafe { + use std::arch::x86_64::*; + x = _mm_hadd_epi32(x, x); + x = _mm_hadd_epi32(x, x); + _mm_cvtsi128_si32(x) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_min_ps(x: std::arch::x86_64::__m256) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let lo = _mm256_castps256_ps128(x); + let hi = _mm256_extractf128_ps(x, 1); + let min = _mm_min_ps(lo, hi); + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), min); + f32::min(f32::min(x.0[0], x.0[1]), f32::min(x.0[2], x.0[3])) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_min_ps(x: std::arch::x86_64::__m128) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let min = x; + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), min); + f32::min(f32::min(x.0[0], x.0[1]), f32::min(x.0[2], x.0[3])) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v3")] +pub fn emulate_mm256_reduce_max_ps(x: std::arch::x86_64::__m256) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let lo = _mm256_castps256_ps128(x); + let hi = _mm256_extractf128_ps(x, 1); + let max = _mm_max_ps(lo, hi); + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), max); + f32::max(f32::max(x.0[0], x.0[1]), f32::max(x.0[2], x.0[3])) + } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[crate::target_cpu(enable = "v2")] +pub fn emulate_mm_reduce_max_ps(x: std::arch::x86_64::__m128) -> f32 { + use crate::aligned::Aligned16; + unsafe { + use std::arch::x86_64::*; + let max = x; + let mut x = Aligned16([0.0f32; 4]); + _mm_store_ps(x.0.as_mut_ptr(), max); + f32::max(f32::max(x.0[0], x.0[1]), f32::max(x.0[2], x.0[3])) + } +} diff --git a/crates/simd/src/f16.rs b/crates/simd/src/f16.rs new file mode 100644 index 0000000..53bddb7 --- /dev/null +++ b/crates/simd/src/f16.rs @@ -0,0 +1,1185 @@ +use crate::{Floating, f32}; +use half::f16; + +impl Floating for f16 { + #[inline(always)] + fn zero() -> Self { + f16::ZERO + } + + #[inline(always)] + fn infinity() -> Self { + f16::INFINITY + } + + #[inline(always)] + fn mask(self, m: bool) -> Self { + f16::from_bits(self.to_bits() & (m as u16).wrapping_neg()) + } + + #[inline(always)] + fn scalar_neg(this: Self) -> Self { + -this + } + + #[inline(always)] + fn scalar_add(lhs: Self, rhs: Self) -> Self { + lhs + rhs + } + + #[inline(always)] + fn scalar_sub(lhs: Self, rhs: Self) -> Self { + lhs - rhs + } + + #[inline(always)] + fn scalar_mul(lhs: Self, rhs: Self) -> Self { + lhs * rhs + } + + #[inline(always)] + fn reduce_or_of_is_zero_x(this: &[f16]) -> bool { + reduce_or_of_is_zero_x::reduce_or_of_is_zero_x(this) + } + + #[inline(always)] + fn reduce_sum_of_x(this: &[f16]) -> f32 { + reduce_sum_of_x::reduce_sum_of_x(this) + } + + #[inline(always)] + fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { + reduce_sum_of_abs_x::reduce_sum_of_abs_x(this) + } + + #[inline(always)] + fn 
reduce_sum_of_x2(this: &[f16]) -> f32 { + reduce_sum_of_x2::reduce_sum_of_x2(this) + } + + #[inline(always)] + fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { + reduce_min_max_of_x::reduce_min_max_of_x(this) + } + + #[inline(always)] + fn reduce_sum_of_xy(lhs: &[Self], rhs: &[Self]) -> f32 { + reduce_sum_of_xy::reduce_sum_of_xy(lhs, rhs) + } + + #[inline(always)] + fn reduce_sum_of_d2(lhs: &[f16], rhs: &[f16]) -> f32 { + reduce_sum_of_d2::reduce_sum_of_d2(lhs, rhs) + } + + #[inline(always)] + fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { + reduce_sum_of_xy_sparse::reduce_sum_of_xy_sparse(lidx, lval, ridx, rval) + } + + #[inline(always)] + fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { + reduce_sum_of_d2_sparse::reduce_sum_of_d2_sparse(lidx, lval, ridx, rval) + } + + #[inline(always)] + fn vector_from_f32(this: &[f32]) -> Vec { + vector_from_f32::vector_from_f32(this) + } + + #[inline(always)] + fn vector_to_f32(this: &[Self]) -> Vec { + vector_to_f32::vector_to_f32(this) + } + + #[inline(always)] + fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_add::vector_add(lhs, rhs) + } + + #[inline(always)] + fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]) { + vector_add_inplace::vector_add_inplace(lhs, rhs) + } + + #[inline(always)] + fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_sub::vector_sub(lhs, rhs) + } + + #[inline(always)] + fn vector_mul(lhs: &[Self], rhs: &[Self]) -> Vec { + vector_mul::vector_mul(lhs, rhs) + } + + #[inline(always)] + fn vector_mul_scalar(lhs: &[Self], rhs: f32) -> Vec { + vector_mul_scalar::vector_mul_scalar(lhs, rhs) + } + + #[inline(always)] + fn vector_mul_scalar_inplace(lhs: &mut [Self], rhs: f32) { + vector_mul_scalar_inplace::vector_mul_scalar_inplace(lhs, rhs) + } + + #[inline(always)] + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { + Self::vector_to_f32(this) + } +} + +mod reduce_or_of_is_zero_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_or_of_is_zero_x(this: &[f16]) -> bool { + for &x in this { + if x == f16::ZERO { + return true; + } + } + false + } +} + +mod reduce_sum_of_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_x(this: &[f16]) -> f32 { + let n = this.len(); + let mut x = 0.0f32; + for i in 0..n { + x += this[i].to_f32(); + } + x + } +} + +mod reduce_sum_of_abs_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_abs_x(this: &[f16]) -> f32 { + let n = this.len(); + let mut x = 0.0f32; + for i in 0..n { + x += this[i].to_f32().abs(); + } + x + } +} + +mod reduce_sum_of_x2 { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_x2(this: &[f16]) -> f32 { + let n = this.len(); + let mut x2 = 0.0f32; + for i in 0..n { + x2 += this[i].to_f32() * this[i].to_f32(); + } + x2 + } +} + +mod reduce_min_max_of_x { + // FIXME: add manually-implemented SIMD version + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + let n = this.len(); 
+ for i in 0..n { + min = min.min(this[i].to_f32()); + max = max.max(this[i].to_f32()); + } + (min, max) + } +} + +mod reduce_sum_of_xy { + use half::f16; + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + pub fn reduce_sum_of_xy_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm512_setzero_ph(); + while n >= 32 { + let x = _mm512_loadu_ph(a.cast()); + let y = _mm512_loadu_ph(b.cast()); + a = a.add(32); + b = b.add(32); + n -= 32; + xy = _mm512_fmadd_ph(x, y, xy); + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())); + let y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b.cast())); + xy = _mm512_fmadd_ph(x, y, xy); + } + _mm512_reduce_add_ph(xy) as f32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v4_avx512fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v4_avx512fp16(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + pub fn reduce_sum_of_xy_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_cvtph_ps(_mm256_loadu_epi16(a.cast())); + let y = _mm512_cvtph_ps(_mm256_loadu_epi16(b.cast())); + a = a.add(16); + b = b.add(16); + n -= 16; + xy = _mm512_fmadd_ps(x, y, xy); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a.cast())); + let y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b.cast())); + xy = _mm512_fmadd_ps(x, y, xy); + } + _mm512_reduce_add_ps(xy) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v4_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let specialized = unsafe { reduce_sum_of_xy_v4(&lhs, &rhs) }; + let fallback = reduce_sum_of_xy_fallback(&lhs, &rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + pub fn reduce_sum_of_xy_v3(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm256_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(8); + b = b.add(8); + n -= 8; + xy = _mm256_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm256_reduce_add_ps(xy); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v3_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "f16c")] + #[target_feature(enable = "fma")] + pub fn reduce_sum_of_xy_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm_reduce_add_ps(xy); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v2_f16c_fma_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v2") + || !crate::is_feature_detected!("f16c") + || !crate::is_feature_detected!("fma") + { + println!("test {} ... skipped (v2:f16c:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v2_f16c_fma(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "fp16")] + pub fn reduce_sum_of_xy_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( + a: *const (), + b: *const (), + n: usize, + ) -> f32; + } + fp16_reduce_sum_of_xy_v8_3a_fp16_unroll( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + lhs.len(), + ) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v8_3a_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (v8.3a:fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_fp16(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + // temporarily disables this for uncertain precision + #[expect(dead_code)] + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + pub fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_xy_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + } + fp16_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:fp16")] + pub fn reduce_sum_of_xy(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = 0.0f32; + for i in 0..n { + xy += lhs[i].to_f32() * rhs[i].to_f32(); + } + xy + } +} + +mod reduce_sum_of_d2 { + use half::f16; + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + #[target_feature(enable = "avx512fp16")] + pub fn reduce_sum_of_d2_v4_avx512fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm512_setzero_ph(); + while n >= 32 { + let x = _mm512_loadu_ph(a.cast()); + let y = _mm512_loadu_ph(b.cast()); + a = a.add(32); + b = b.add(32); + n -= 32; + let d = _mm512_sub_ph(x, y); + d2 = _mm512_fmadd_ph(d, d, d2); + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n); + let x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a.cast())); + let y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b.cast())); + let d = _mm512_sub_ph(x, y); + d2 = _mm512_fmadd_ph(d, d, d2); + } + _mm512_reduce_add_ph(d2) as f32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v4_avx512fp16_test() { + use rand::Rng; + const EPSILON: f32 = 6.0; + if !crate::is_cpu_detected!("v4") || !crate::is_feature_detected!("avx512fp16") { + println!("test {} ... skipped (v4:avx512fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v4_avx512fp16(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + pub fn reduce_sum_of_d2_v4(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_cvtph_ps(_mm256_loadu_epi16(a.cast())); + let y = _mm512_cvtph_ps(_mm256_loadu_epi16(b.cast())); + a = a.add(16); + b = b.add(16); + n -= 16; + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n) as u16; + let x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a.cast())); + let y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b.cast())); + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); + } + _mm512_reduce_add_ps(d2) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v4_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + pub fn reduce_sum_of_d2_v3(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm256_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(8); + b = b.add(8); + n -= 8; + let d = _mm256_sub_ps(x, y); + d2 = _mm256_fmadd_ps(d, d, d2); + } + let mut d2 = emulate_mm256_reduce_add_ps(d2); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v3_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "f16c")] + #[target_feature(enable = "fma")] + pub fn reduce_sum_of_d2_v2_f16c_fma(lhs: &[f16], rhs: &[f16]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_cvtph_ps(_mm_loadu_si128(a.cast())); + let y = _mm_cvtph_ps(_mm_loadu_si128(b.cast())); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = _mm_sub_ps(x, y); + d2 = _mm_fmadd_ps(d, d, d2); + } + let mut d2 = emulate_mm_reduce_add_ps(d2); + while n > 0 { + let x = a.read().to_f32(); + let y = b.read().to_f32(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v2_f16c_fma_test() { + use rand::Rng; + const EPSILON: f32 = 2.0; + if !crate::is_cpu_detected!("v2") + || !crate::is_feature_detected!("f16c") + || !crate::is_feature_detected!("fma") + { + println!("test {} ... 
skipped (v2:f16c:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v2_f16c_fma(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "fp16")] + pub fn reduce_sum_of_d2_v8_3a_fp16(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( + a: *const (), + b: *const (), + n: usize, + ) -> f32; + } + fp16_reduce_sum_of_d2_v8_3a_fp16_unroll( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + lhs.len(), + ) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v8_3a_fp16_test() { + use rand::Rng; + const EPSILON: f32 = 6.0; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("fp16") { + println!("test {} ... skipped (v8.3a:fp16)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_fp16(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + // temporarily disables this for uncertain precision + #[expect(dead_code)] + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + pub fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp16_reduce_sum_of_d2_v8_3a_sve(a: *const (), b: *const (), n: usize) -> f32; + } + fp16_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr().cast(), rhs.as_ptr().cast(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 6.0; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + let rhs = (0..n) + .map(|_| f16::from_f32(rng.gen_range(-1.0..=1.0))) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::multiversion(@"v4:avx512fp16", @"v4", @"v3", @"v2:f16c:fma", @"v8.3a:fp16")] + pub fn reduce_sum_of_d2(lhs: &[f16], rhs: &[f16]) -> f32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = 0.0; + for i in 0..n { + let d = lhs[i].to_f32() - rhs[i].to_f32(); + d2 += d * d; + } + d2 + } +} + +mod reduce_sum_of_xy_sparse { + // There is no manually-implemented SIMD version. + // Add it if `svecf16` is supported. + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { + use std::cmp::Ordering; + assert_eq!(lidx.len(), lval.len()); + assert_eq!(ridx.len(), rval.len()); + let (mut lp, ln) = (0, lidx.len()); + let (mut rp, rn) = (0, ridx.len()); + let mut xy = 0.0f32; + while lp < ln && rp < rn { + match Ord::cmp(&lidx[lp], &ridx[rp]) { + Ordering::Equal => { + xy += lval[lp].to_f32() * rval[rp].to_f32(); + lp += 1; + rp += 1; + } + Ordering::Less => { + lp += 1; + } + Ordering::Greater => { + rp += 1; + } + } + } + xy + } +} + +mod reduce_sum_of_d2_sparse { + // There is no manually-implemented SIMD version. + // Add it if `svecf16` is supported. + + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[f16], ridx: &[u32], rval: &[f16]) -> f32 { + use std::cmp::Ordering; + assert_eq!(lidx.len(), lval.len()); + assert_eq!(ridx.len(), rval.len()); + let (mut lp, ln) = (0, lidx.len()); + let (mut rp, rn) = (0, ridx.len()); + let mut d2 = 0.0f32; + while lp < ln && rp < rn { + match Ord::cmp(&lidx[lp], &ridx[rp]) { + Ordering::Equal => { + let d = lval[lp].to_f32() - rval[rp].to_f32(); + d2 += d * d; + lp += 1; + rp += 1; + } + Ordering::Less => { + d2 += lval[lp].to_f32() * lval[lp].to_f32(); + lp += 1; + } + Ordering::Greater => { + d2 += rval[rp].to_f32() * rval[rp].to_f32(); + rp += 1; + } + } + } + for i in lp..ln { + d2 += lval[i].to_f32() * lval[i].to_f32(); + } + for i in rp..rn { + d2 += rval[i].to_f32() * rval[i].to_f32(); + } + d2 + } +} + +mod vector_add { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_add_inplace { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add_inplace(lhs: &mut [f16], rhs: &[f16]) { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + for i in 0..n { + lhs[i] += rhs[i]; + } + } +} + +mod vector_sub { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_sub(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul(lhs: &[f16], rhs: &[f16]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} 
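The allocating `vector_*` helpers in this file all follow the same shape: reserve an uninitialized buffer with `with_capacity`, write each lane through a raw pointer, then `set_len`, presumably so the multiversioned body compiles down to a plain loop the compiler can vectorize without the per-element capacity checks of `Vec::push`. A minimal standalone sketch of that pattern, using a hypothetical `elementwise` helper that is not part of the patch:

fn elementwise<T: Copy>(lhs: &[T], rhs: &[T], f: impl Fn(T, T) -> T) -> Vec<T> {
    assert_eq!(lhs.len(), rhs.len());
    let n = lhs.len();
    let mut r = Vec::<T>::with_capacity(n);
    for i in 0..n {
        // SAFETY: i < n <= capacity, and slot i is written exactly once.
        unsafe {
            r.as_mut_ptr().add(i).write(f(lhs[i], rhs[i]));
        }
    }
    // SAFETY: all n elements were initialized by the loop above.
    unsafe {
        r.set_len(n);
    }
    r
}

In these terms, `vector_mul(lhs, rhs)` is `elementwise(lhs, rhs, |x, y| x * y)`; the `#[crate::multiversion(...)]` attribute then supplies the per-target (`v4`, `v3`, `v2`, `v8.3a`) versions of that loop.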
+ +mod vector_mul_scalar { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar(lhs: &[f16], rhs: f32) -> Vec { + let rhs = f16::from_f32(rhs); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar_inplace { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar_inplace(lhs: &mut [f16], rhs: f32) { + let rhs = f16::from_f32(rhs); + let n = lhs.len(); + for i in 0..n { + lhs[i] *= rhs; + } + } +} + +mod vector_abs_inplace { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_abs_inplace(this: &mut [f16]) { + let n = this.len(); + for i in 0..n { + this[i] = f16::from_f32(this[i].to_f32().abs()); + } + } +} + +mod vector_from_f32 { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_from_f32(this: &[f32]) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(f16::from_f32(this[i])); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_to_f32 { + use half::f16; + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_to_f32(this: &[f16]) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(this[i].to_f32()); + } + } + unsafe { + r.set_len(n); + } + r + } +} diff --git a/crates/simd/src/f32.rs b/crates/simd/src/f32.rs new file mode 100644 index 0000000..555da5c --- /dev/null +++ b/crates/simd/src/f32.rs @@ -0,0 +1,2277 @@ +use crate::Floating; + +impl Floating for f32 { + #[inline(always)] + fn zero() -> Self { + 0.0f32 + } + + #[inline(always)] + fn infinity() -> Self { + f32::INFINITY + } + + #[inline(always)] + fn mask(self, m: bool) -> Self { + f32::from_bits(self.to_bits() & (m as u32).wrapping_neg()) + } + + #[inline(always)] + fn scalar_neg(this: Self) -> Self { + -this + } + + #[inline(always)] + fn scalar_add(lhs: Self, rhs: Self) -> Self { + lhs + rhs + } + + #[inline(always)] + fn scalar_sub(lhs: Self, rhs: Self) -> Self { + lhs - rhs + } + + #[inline(always)] + fn scalar_mul(lhs: Self, rhs: Self) -> Self { + lhs * rhs + } + + #[inline(always)] + fn reduce_or_of_is_zero_x(this: &[f32]) -> bool { + reduce_or_of_is_zero_x::reduce_or_of_is_zero_x(this) + } + + #[inline(always)] + fn reduce_sum_of_x(this: &[f32]) -> f32 { + reduce_sum_of_x::reduce_sum_of_x(this) + } + + #[inline(always)] + fn reduce_sum_of_abs_x(this: &[f32]) -> f32 { + reduce_sum_of_abs_x::reduce_sum_of_abs_x(this) + } + + #[inline(always)] + fn reduce_sum_of_x2(this: &[f32]) -> f32 { + reduce_sum_of_x2::reduce_sum_of_x2(this) + } + + #[inline(always)] + fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) { + reduce_min_max_of_x::reduce_min_max_of_x(this) + } + + #[inline(always)] + fn reduce_sum_of_xy(lhs: &[Self], rhs: &[Self]) -> f32 { + reduce_sum_of_xy::reduce_sum_of_xy(lhs, rhs) + } + + #[inline(always)] + fn reduce_sum_of_d2(lhs: &[Self], rhs: &[Self]) -> f32 { + reduce_sum_of_d2::reduce_sum_of_d2(lhs, rhs) + } + + #[inline(always)] + fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { + reduce_sum_of_xy_sparse::reduce_sum_of_xy_sparse(lidx, lval, ridx, rval) + } + + #[inline(always)] + fn reduce_sum_of_d2_sparse(lidx: &[u32], 
lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { + reduce_sum_of_d2_sparse::reduce_sum_of_d2_sparse(lidx, lval, ridx, rval) + } + + fn vector_add(lhs: &[f32], rhs: &[f32]) -> Vec { + vector_add::vector_add(lhs, rhs) + } + + fn vector_add_inplace(lhs: &mut [f32], rhs: &[f32]) { + vector_add_inplace::vector_add_inplace(lhs, rhs) + } + + fn vector_sub(lhs: &[f32], rhs: &[f32]) -> Vec { + vector_sub::vector_sub(lhs, rhs) + } + + fn vector_mul(lhs: &[f32], rhs: &[f32]) -> Vec { + vector_mul::vector_mul(lhs, rhs) + } + + fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec { + vector_mul_scalar::vector_mul_scalar(lhs, rhs) + } + + fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) { + vector_mul_scalar_inplace::vector_mul_scalar_inplace(lhs, rhs); + } + + fn vector_from_f32(this: &[f32]) -> Vec { + this.to_vec() + } + + fn vector_to_f32(this: &[f32]) -> Vec { + this.to_vec() + } + + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> { + this + } +} + +mod reduce_or_of_is_zero_x { + // FIXME: add manually-implemented SIMD version + + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_or_of_is_zero_x(this: &[f32]) -> bool { + for &x in this { + if x == 0.0f32 { + return true; + } + } + false + } +} + +mod reduce_sum_of_x { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_x_v4(this: &[f32]) -> f32 { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + a = a.add(16); + n -= 16; + sum = _mm512_add_ps(x, sum); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + sum = _mm512_add_ps(x, sum); + } + _mm512_reduce_add_ps(sum) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v4(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_x_v3(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + a = a.add(8); + n -= 8; + sum = _mm256_add_ps(x, sum); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + a = a.add(4); + n -= 4; + sum = _mm256_add_ps(x, sum); + } + let mut sum = emulate_mm256_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... 
skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v3(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn reduce_sum_of_x_v2(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + sum = _mm_add_ps(x, sum); + } + let mut sum = emulate_mm_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_v2_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v2(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x_v8_3a(this: &[f32]) -> f32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + a = a.add(4); + n -= 4; + sum = vaddq_f32(x, sum); + } + let mut sum = vaddvq_f32(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8_3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v8_3a(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_x_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_x_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8_3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v8_3a_sve(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_x(this: &[f32]) -> f32 { + let n = this.len(); + let mut sum = 0.0f32; + for i in 0..n { + sum += this[i]; + } + sum + } +} + +mod reduce_sum_of_abs_x { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_abs_x_v4(this: &[f32]) -> f32 { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + let abs_x = _mm512_abs_ps(x); + a = a.add(16); + n -= 16; + sum = _mm512_add_ps(abs_x, sum); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + let abs_x = _mm512_abs_ps(x); + sum = _mm512_add_ps(abs_x, sum); + } + _mm512_reduce_add_ps(sum) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_abs_x_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v4(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_abs_x_v3(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let abs = _mm256_castsi256_ps(_mm256_srli_epi32(_mm256_set1_epi32(-1), 1)); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + let abs_x = _mm256_and_ps(abs, x); + a = a.add(8); + n -= 8; + sum = _mm256_add_ps(abs_x, sum); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + let abs_x = _mm256_and_ps(abs, x); + a = a.add(4); + n -= 4; + sum = _mm256_add_ps(abs_x, sum); + } + let mut sum = emulate_mm256_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let abs_x = x.abs(); + a = a.add(1); + n -= 1; + sum += abs_x; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_abs_x_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v3(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn reduce_sum_of_abs_x_v2(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let abs = _mm_castsi128_ps(_mm_srli_epi32(_mm_set1_epi32(-1), 1)); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let abs_x = _mm_and_ps(abs, x); + a = a.add(4); + n -= 4; + sum = _mm_add_ps(abs_x, sum); + } + let mut sum = emulate_mm_reduce_add_ps(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let abs_x = x.abs(); + a = a.add(1); + n -= 1; + sum += abs_x; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_abs_x_v2_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v2(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_abs_x_v8_3a(this: &[f32]) -> f32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let abs_x = vabsq_f32(x); + a = a.add(4); + n -= 4; + sum = vaddq_f32(abs_x, sum); + } + let mut sum = vaddvq_f32(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let abs_x = x.abs(); + a = a.add(1); + n -= 1; + sum += abs_x; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_abs_x_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v8_3a(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_abs_x_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_abs_x_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_abs_x_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_abs_x_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.008; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_abs_x_v8_3a_sve(this) }; + let fallback = reduce_sum_of_abs_x_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_abs_x(this: &[f32]) -> f32 { + let n = this.len(); + let mut sum = 0.0f32; + for i in 0..n { + sum += this[i].abs(); + } + sum + } +} + +mod reduce_sum_of_x2 { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_x2_v4(this: &[f32]) -> f32 { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + a = a.add(16); + n -= 16; + x2 = _mm512_fmadd_ps(x, x, x2); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + x2 = _mm512_fmadd_ps(x, x, x2); + } + _mm512_reduce_add_ps(x2) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_x2_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v4(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_x2_v3(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + a = a.add(8); + n -= 8; + x2 = _mm256_fmadd_ps(x, x, x2); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + a = a.add(4); + n -= 4; + x2 = _mm256_fmadd_ps(x, x, x2); + } + let mut x2 = emulate_mm256_reduce_add_ps(x2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + x2 += x * x; + } + x2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x2_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v3(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_x2_v2_fma(this: &[f32]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + x2 = _mm_fmadd_ps(x, x, x2); + } + let mut x2 = emulate_mm_reduce_add_ps(x2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + x2 += x * x; + } + x2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x2_v2_fma_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::is_cpu_detected!("v2") || !crate::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v2_fma(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x2_v8_3a(this: &[f32]) -> f32 { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut x2 = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + a = a.add(4); + n -= 4; + x2 = vfmaq_f32(x2, x, x); + } + let mut x2 = vaddvq_f32(x2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + x2 += x * x; + } + x2 + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_x2_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v8_3a(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_x2_v8_3a_sve(this: &[f32]) -> f32 { + unsafe { + extern "C" { + fn fp32_reduce_sum_of_x2_v8_3a_sve(this: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_x2_v8_3a_sve(this.as_ptr(), this.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_x2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.006; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... 
skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x2_v8_3a_sve(this) }; + let fallback = reduce_sum_of_x2_fallback(this); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_x2(this: &[f32]) -> f32 { + let n = this.len(); + let mut x2 = 0.0f32; + for i in 0..n { + x2 += this[i] * this[i]; + } + x2 + } +} + +mod reduce_min_max_of_x { + // Semanctics of `f32::min` is different from `_mm256_min_ps`, + // which may lead to issues... + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_min_max_of_x_v4(this: &[f32]) -> (f32, f32) { + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm512_set1_ps(f32::INFINITY); + let mut max = _mm512_set1_ps(f32::NEG_INFINITY); + while n >= 16 { + let x = _mm512_loadu_ps(a); + a = a.add(16); + n -= 16; + min = _mm512_min_ps(x, min); + max = _mm512_max_ps(x, max); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + min = _mm512_mask_min_ps(min, mask, x, min); + max = _mm512_mask_max_ps(max, mask, x, max); + } + let min = _mm512_reduce_min_ps(min); + let max = _mm512_reduce_max_ps(max); + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_min_max_of_x_v4_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v4(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0); + assert_eq!(specialized.1, fallback.1); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_min_max_of_x_v3(this: &[f32]) -> (f32, f32) { + use crate::emulate::emulate_mm256_reduce_max_ps; + use crate::emulate::emulate_mm256_reduce_min_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm256_set1_ps(f32::INFINITY); + let mut max = _mm256_set1_ps(f32::NEG_INFINITY); + while n >= 8 { + let x = _mm256_loadu_ps(a); + a = a.add(8); + n -= 8; + min = _mm256_min_ps(x, min); + max = _mm256_max_ps(x, max); + } + let mut min = emulate_mm256_reduce_min_ps(min); + let mut max = emulate_mm256_reduce_max_ps(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_min_max_of_x_v3_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... 
skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v3(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn reduce_min_max_of_x_v2(this: &[f32]) -> (f32, f32) { + use crate::emulate::emulate_mm_reduce_max_ps; + use crate::emulate::emulate_mm_reduce_min_ps; + unsafe { + use std::arch::x86_64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = _mm_set1_ps(f32::INFINITY); + let mut max = _mm_set1_ps(f32::NEG_INFINITY); + while n >= 4 { + let x = _mm_loadu_ps(a); + a = a.add(4); + n -= 4; + min = _mm_min_ps(x, min); + max = _mm_max_ps(x, max); + } + let mut min = emulate_mm_reduce_min_ps(min); + let mut max = emulate_mm_reduce_max_ps(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_min_max_of_x_v2_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v2") { + println!("test {} ... skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v2(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_min_max_of_x_v8_3a(this: &[f32]) -> (f32, f32) { + unsafe { + use std::arch::aarch64::*; + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut min = vdupq_n_f32(f32::INFINITY); + let mut max = vdupq_n_f32(f32::NEG_INFINITY); + while n >= 4 { + let x = vld1q_f32(a); + a = a.add(4); + n -= 4; + min = vminq_f32(x, min); + max = vmaxq_f32(x, max); + } + let mut min = vminvq_f32(min); + let mut max = vmaxvq_f32(max); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + min = x.min(min); + max = x.max(max); + } + (min, max) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_min_max_of_x_v8_3a_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v8_3a(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_min_max_of_x_v8_3a_sve(this: &[f32]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = -f32::INFINITY; + unsafe { + extern "C" { + fn fp32_reduce_min_max_of_x_v8_3a_sve( + this: *const f32, + n: usize, + out_min: &mut f32, + out_max: &mut f32, + ); + } + fp32_reduce_min_max_of_x_v8_3a_sve(this.as_ptr(), this.len(), &mut min, &mut max); + } + (min, max) + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_min_max_of_x_v8_3a_sve_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 200; + let x = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 50..200 { + let x = &x[..z]; + let specialized = unsafe { reduce_min_max_of_x_v8_3a_sve(x) }; + let fallback = reduce_min_max_of_x_fallback(x); + assert_eq!(specialized.0, fallback.0,); + assert_eq!(specialized.1, fallback.1,); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) { + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + let n = this.len(); + for i in 0..n { + min = min.min(this[i]); + max = max.max(this[i]); + } + (min, max) + } +} + +mod reduce_sum_of_xy { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_xy_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + let y = _mm512_loadu_ps(b); + a = a.add(16); + b = b.add(16); + n -= 16; + xy = _mm512_fmadd_ps(x, y, xy); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + let y = _mm512_maskz_loadu_ps(mask, b); + xy = _mm512_fmadd_ps(x, y, xy); + } + _mm512_reduce_add_ps(xy) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v4(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_xy_v3(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + let y = _mm256_loadu_ps(b); + a = a.add(8); + b = b.add(8); + n -= 8; + xy = _mm256_fmadd_ps(x, y, xy); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + let y = _mm256_zextps128_ps256(_mm_loadu_ps(b)); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm256_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm256_reduce_add_ps(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v3(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_xy_v2_fma(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let y = _mm_loadu_ps(b); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = _mm_fmadd_ps(x, y, xy); + } + let mut xy = emulate_mm_reduce_add_ps(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_xy_v2_fma_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::is_cpu_detected!("v2") || !crate::is_feature_detected!("fma") { + println!("test {} ... 
skipped (v2:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v2_fma(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_xy_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::aarch64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut xy = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let y = vld1q_f32(b); + a = a.add(4); + b = b.add(4); + n -= 4; + xy = vfmaq_f32(xy, x, y); + } + let mut xy = vaddvq_f32(xy); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + xy += x * y; + } + xy + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_xy_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp32_reduce_sum_of_xy_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_xy_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.004; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_xy_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_xy_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = 0.0f32; + for i in 0..n { + xy += lhs[i] * rhs[i]; + } + xy + } +} + +mod reduce_sum_of_d2 { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_d2_v4(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len() as u32; + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm512_setzero_ps(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + let y = _mm512_loadu_ps(b); + a = a.add(16); + b = b.add(16); + n -= 16; + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + let y = _mm512_maskz_loadu_ps(mask, b); + let d = _mm512_sub_ps(x, y); + d2 = _mm512_fmadd_ps(d, d, d2); + } + _mm512_reduce_add_ps(d2) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v4(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_d2_v3(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::emulate::emulate_mm256_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm256_setzero_ps(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + let y = _mm256_loadu_ps(b); + a = a.add(8); + b = b.add(8); + n -= 8; + let d = _mm256_sub_ps(x, y); + d2 = _mm256_fmadd_ps(d, d, d2); + } + if n >= 4 { + let x = _mm256_zextps128_ps256(_mm_loadu_ps(a)); + let y = _mm256_zextps128_ps256(_mm_loadu_ps(b)); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = _mm256_sub_ps(x, y); + d2 = _mm256_fmadd_ps(d, d, d2); + } + let mut d2 = emulate_mm256_reduce_add_ps(d2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_d2_v3_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... 
skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v3(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn reduce_sum_of_d2_v2_fma(lhs: &[f32], rhs: &[f32]) -> f32 { + use crate::emulate::emulate_mm_reduce_add_ps; + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::x86_64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = _mm_setzero_ps(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let y = _mm_loadu_ps(b); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = _mm_sub_ps(x, y); + d2 = _mm_fmadd_ps(d, d, d2); + } + let mut d2 = emulate_mm_reduce_add_ps(d2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_d2_v2_fma_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::is_cpu_detected!("v2") || !crate::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v2_fma(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_d2_v8_3a(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + use std::arch::aarch64::*; + let mut n = lhs.len(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut d2 = vdupq_n_f32(0.0); + while n >= 4 { + let x = vld1q_f32(a); + let y = vld1q_f32(b); + a = a.add(4); + b = b.add(4); + n -= 4; + let d = vsubq_f32(x, y); + d2 = vfmaq_f32(d2, d, d); + } + let mut d2 = vaddvq_f32(d2); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let y = b.read(); + a = a.add(1); + b = b.add(1); + n -= 1; + let d = x - y; + d2 += d * d; + } + d2 + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v8_3a_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + #[target_feature(enable = "sve")] + fn reduce_sum_of_d2_v8_3a_sve(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + unsafe { + extern "C" { + fn fp32_reduce_sum_of_d2_v8_3a_sve(a: *const f32, b: *const f32, n: usize) -> f32; + } + fp32_reduce_sum_of_d2_v8_3a_sve(lhs.as_ptr(), rhs.as_ptr(), lhs.len()) + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_v8_3a_sve_test() { + use rand::Rng; + const EPSILON: f32 = 0.02; + if !crate::is_cpu_detected!("v8.3a") || !crate::is_feature_detected!("sve") { + println!("test {} ... skipped (v8.3a:sve)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let lhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rhs = (0..n) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + for z in 3984..4016 { + let lhs = &lhs[..z]; + let rhs = &rhs[..z]; + let specialized = unsafe { reduce_sum_of_d2_v8_3a_sve(lhs, rhs) }; + let fallback = reduce_sum_of_d2_fallback(lhs, rhs); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a:sve", @"v8.3a")] + pub fn reduce_sum_of_d2(lhs: &[f32], rhs: &[f32]) -> f32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = 0.0f32; + for i in 0..n { + let d = lhs[i] - rhs[i]; + d2 += d * d; + } + d2 + } +} + +mod reduce_sum_of_xy_sparse { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_xy_sparse_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + use crate::emulate::emulate_mm512_2intersect_epi32; + assert_eq!(li.len(), lv.len()); + assert_eq!(ri.len(), rv.len()); + let (mut lp, ln) = (0, li.len()); + let (mut rp, rn) = (0, ri.len()); + let (li, lv) = (li.as_ptr(), lv.as_ptr()); + let (ri, rv) = (ri.as_ptr(), rv.as_ptr()); + unsafe { + use std::arch::x86_64::*; + let mut xy = _mm512_setzero_ps(); + while lp + 16 <= ln && rp + 16 <= rn { + let lx = _mm512_loadu_epi32(li.add(lp).cast()); + let rx = _mm512_loadu_epi32(ri.add(rp).cast()); + let (lk, rk) = emulate_mm512_2intersect_epi32(lx, rx); + let lv = _mm512_maskz_compress_ps(lk, _mm512_loadu_ps(lv.add(lp))); + let rv = _mm512_maskz_compress_ps(rk, _mm512_loadu_ps(rv.add(rp))); + xy = _mm512_fmadd_ps(lv, rv, xy); + let lt = li.add(lp + 16 - 1).read(); + let rt = ri.add(rp + 16 - 1).read(); + lp += (lt <= rt) as usize * 16; + rp += (lt >= rt) as usize * 16; + } + while lp < ln && rp < rn { + let lw = 16.min(ln - lp); + let rw = 16.min(rn - rp); + let lm = _bzhi_u32(0xffff, lw as _) as u16; + let rm = _bzhi_u32(0xffff, rw as _) as u16; + let lx = _mm512_mask_loadu_epi32(_mm512_set1_epi32(-1), lm, li.add(lp).cast()); + let rx = _mm512_mask_loadu_epi32(_mm512_set1_epi32(-1), rm, ri.add(rp).cast()); + let (lk, rk) = emulate_mm512_2intersect_epi32(lx, rx); + let lv = _mm512_maskz_compress_ps(lk, _mm512_maskz_loadu_ps(lm, lv.add(lp))); + let rv = _mm512_maskz_compress_ps(rk, _mm512_maskz_loadu_ps(rm, rv.add(rp))); + xy = _mm512_fmadd_ps(lv, rv, xy); + let lt = li.add(lp + lw - 1).read(); + let rt = ri.add(rp + rw - 1).read(); + lp += (lt <= rt) as usize * lw; + rp += (lt >= rt) as usize * rw; + } + _mm512_reduce_add_ps(xy) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_xy_sparse_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.000001; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + pub fn sample_u32_sorted( + rng: &mut (impl Rng + ?Sized), + length: u32, + amount: u32, + ) -> Vec { + let mut x = match rand::seq::index::sample(rng, length as usize, amount as usize) { + rand::seq::index::IndexVec::U32(x) => x, + _ => unreachable!(), + }; + x.sort(); + x + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lm = 300; + let lidx = sample_u32_sorted(&mut rng, 10000, lm); + let lval = (0..lm) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rm = 350; + let ridx = sample_u32_sorted(&mut rng, 10000, rm); + let rval = (0..rm) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let specialized = unsafe { reduce_sum_of_xy_sparse_v4(&lidx, &lval, &ridx, &rval) }; + let fallback = reduce_sum_of_xy_sparse_fallback(&lidx, &lval, &ridx, &rval); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } + } + + #[crate::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { + use std::cmp::Ordering; + assert_eq!(lidx.len(), lval.len()); + assert_eq!(ridx.len(), rval.len()); + let (mut lp, ln) = (0, lidx.len()); + let (mut rp, rn) = (0, ridx.len()); + let mut xy = 0.0f32; + while lp < ln && rp < rn { + match Ord::cmp(&lidx[lp], &ridx[rp]) { + Ordering::Equal => { + xy += lval[lp] * rval[rp]; + lp += 1; + rp += 1; + } + Ordering::Less => { + lp += 1; + } + Ordering::Greater => { + rp += 1; + } + } + } + xy + } +} + +mod reduce_sum_of_d2_sparse { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_d2_sparse_v4(li: &[u32], lv: &[f32], ri: &[u32], rv: &[f32]) -> f32 { + use crate::emulate::emulate_mm512_2intersect_epi32; + assert_eq!(li.len(), lv.len()); + assert_eq!(ri.len(), rv.len()); + let (mut lp, ln) = (0, li.len()); + let (mut rp, rn) = (0, ri.len()); + let (li, lv) = (li.as_ptr(), lv.as_ptr()); + let (ri, rv) = (ri.as_ptr(), rv.as_ptr()); + unsafe { + use std::arch::x86_64::*; + let mut d2 = _mm512_setzero_ps(); + while lp + 16 <= ln && rp + 16 <= rn { + let lx = _mm512_loadu_epi32(li.add(lp).cast()); + let rx = _mm512_loadu_epi32(ri.add(rp).cast()); + let (lk, rk) = emulate_mm512_2intersect_epi32(lx, rx); + let lv = _mm512_maskz_compress_ps(lk, _mm512_loadu_ps(lv.add(lp))); + let rv = _mm512_maskz_compress_ps(rk, _mm512_loadu_ps(rv.add(rp))); + let d = _mm512_sub_ps(lv, rv); + d2 = _mm512_fmadd_ps(d, d, d2); + d2 = _mm512_sub_ps(d2, _mm512_mul_ps(lv, lv)); + d2 = _mm512_sub_ps(d2, _mm512_mul_ps(rv, rv)); + let lt = li.add(lp + 16 - 1).read(); + let rt = ri.add(rp + 16 - 1).read(); + lp += (lt <= rt) as usize * 16; + rp += (lt >= rt) as usize * 16; + } + while lp < ln && rp < rn { + let lw = 16.min(ln - lp); + let rw = 16.min(rn - rp); + let lm = _bzhi_u32(0xffff, lw as _) as u16; + let rm = _bzhi_u32(0xffff, rw as _) as u16; + let lx = _mm512_mask_loadu_epi32(_mm512_set1_epi32(-1), lm, li.add(lp).cast()); + let rx = _mm512_mask_loadu_epi32(_mm512_set1_epi32(-1), rm, ri.add(rp).cast()); + let (lk, rk) = emulate_mm512_2intersect_epi32(lx, rx); + let lv = _mm512_maskz_compress_ps(lk, _mm512_maskz_loadu_ps(lm, lv.add(lp))); + let rv = _mm512_maskz_compress_ps(rk, _mm512_maskz_loadu_ps(rm, rv.add(rp))); + let d = _mm512_sub_ps(lv, rv); + d2 = _mm512_fmadd_ps(d, d, d2); + d2 = _mm512_sub_ps(d2, _mm512_mul_ps(lv, lv)); + d2 = _mm512_sub_ps(d2, _mm512_mul_ps(rv, rv)); + let lt = li.add(lp + lw - 1).read(); + let rt = ri.add(rp + rw - 1).read(); + lp += (lt <= rt) as usize * lw; + rp += (lt >= rt) as usize * rw; + } + { + let mut lp = 0; + while lp + 16 <= ln { + let d = _mm512_loadu_ps(lv.add(lp)); + d2 = _mm512_fmadd_ps(d, d, d2); + lp += 16; + } + if lp < ln { + let lw = ln - lp; + let lm = _bzhi_u32(0xffff, lw as _) as u16; + let d = _mm512_maskz_loadu_ps(lm, lv.add(lp)); + d2 = _mm512_fmadd_ps(d, d, d2); + } + } + { + let mut rp = 0; + while rp + 16 <= rn { + let d = _mm512_loadu_ps(rv.add(rp)); + d2 = _mm512_fmadd_ps(d, d, d2); + rp += 16; + } + if rp < rn { + let rw = rn - rp; + let rm = _bzhi_u32(0xffff, rw as _) as u16; + let d = _mm512_maskz_loadu_ps(rm, rv.add(rp)); + d2 = _mm512_fmadd_ps(d, d, d2); + } + } + _mm512_reduce_add_ps(d2) + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_d2_sparse_v4_test() { + use rand::Rng; + const EPSILON: f32 = 0.0004; + if 
!crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + pub fn sample_u32_sorted( + rng: &mut (impl Rng + ?Sized), + length: u32, + amount: u32, + ) -> Vec { + let mut x = match rand::seq::index::sample(rng, length as usize, amount as usize) { + rand::seq::index::IndexVec::U32(x) => x, + _ => unreachable!(), + }; + x.sort(); + x + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let lm = 300; + let lidx = sample_u32_sorted(&mut rng, 10000, lm); + let lval = (0..lm) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let rm = 350; + let ridx = sample_u32_sorted(&mut rng, 10000, rm); + let rval = (0..rm) + .map(|_| rng.gen_range(-1.0..=1.0)) + .collect::>(); + let specialized = unsafe { reduce_sum_of_d2_sparse_v4(&lidx, &lval, &ridx, &rval) }; + let fallback = reduce_sum_of_d2_sparse_fallback(&lidx, &lval, &ridx, &rval); + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } + } + + #[crate::multiversion(@"v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[f32], ridx: &[u32], rval: &[f32]) -> f32 { + use std::cmp::Ordering; + assert_eq!(lidx.len(), lval.len()); + assert_eq!(ridx.len(), rval.len()); + let (mut lp, ln) = (0, lidx.len()); + let (mut rp, rn) = (0, ridx.len()); + let mut d2 = 0.0f32; + while lp < ln && rp < rn { + match Ord::cmp(&lidx[lp], &ridx[rp]) { + Ordering::Equal => { + let d = lval[lp] - rval[rp]; + d2 += d * d; + lp += 1; + rp += 1; + } + Ordering::Less => { + d2 += lval[lp] * lval[lp]; + lp += 1; + } + Ordering::Greater => { + d2 += rval[rp] * rval[rp]; + rp += 1; + } + } + } + for i in lp..ln { + d2 += lval[i] * lval[i]; + } + for i in rp..rn { + d2 += rval[i] * rval[i]; + } + d2 + } +} + +mod vector_add { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] + rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_add_inplace { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_add_inplace(lhs: &mut [f32], rhs: &[f32]) { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + for i in 0..n { + lhs[i] += rhs[i]; + } + } +} + +mod vector_sub { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_sub(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] - rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul(lhs: &[f32], rhs: &[f32]) -> Vec { + assert_eq!(lhs.len(), rhs.len()); + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs[i]); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod vector_mul_scalar { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec { + let n = lhs.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + unsafe { + r.as_mut_ptr().add(i).write(lhs[i] * rhs); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +mod 
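usage_sketch {
+    //! A minimal editorial sketch, not part of the original patch: it illustrates how
+    //! the private sibling modules in this file compose, assuming the `multiversion`
+    //! wrappers keep a plain `pub fn` with the module's name (as the fallback
+    //! definitions above suggest). `squared_l2_distance_naive` and
+    //! `squared_l2_distance_fused` are hypothetical helpers used only for comparison;
+    //! the crate exposes these reductions through the `Floating` trait in `lib.rs`,
+    //! not through this module.
+    #[allow(dead_code)]
+    fn squared_l2_distance_naive(lhs: &[f32], rhs: &[f32]) -> f32 {
+        // Two passes: materialize the difference, then reduce its squared norm.
+        let d = super::vector_sub::vector_sub(lhs, rhs);
+        super::reduce_sum_of_x2::reduce_sum_of_x2(&d)
+    }
+
+    #[allow(dead_code)]
+    fn squared_l2_distance_fused(lhs: &[f32], rhs: &[f32]) -> f32 {
+        // One fused pass; mathematically equal to the naive version, sum((x - y)^2),
+        // but without allocating the temporary difference vector.
+        super::reduce_sum_of_d2::reduce_sum_of_d2(lhs, rhs)
+    }
+}
+
+mod 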
vector_mul_scalar_inplace { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) { + let n = lhs.len(); + for i in 0..n { + lhs[i] *= rhs; + } + } +} + +mod vector_abs_inplace { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn vector_abs_inplace(this: &mut [f32]) { + let n = this.len(); + for i in 0..n { + this[i] = this[i].abs(); + } + } +} diff --git a/crates/simd/src/fast_scan/mod.rs b/crates/simd/src/fast_scan/mod.rs new file mode 100644 index 0000000..c2579af --- /dev/null +++ b/crates/simd/src/fast_scan/mod.rs @@ -0,0 +1,496 @@ +/* + +## codes layout for 4-bit quantizer + +group i = | vector i | (total bytes = width/2) + +byte: | 0 | 1 | 2 | ... | width/2 - 1 | +bits 0..3: | code 0 | code 2 | code 4 | ... | code width-2 | +bits 4..7: | code 1 | code 3 | code 5 | ... | code width-1 | + +## packed_codes layout for 4-bit quantizer + +group i = | vector 32i | vector 32i+1 | vector 32i+2 | ... | vector 32i+31 | (total bytes = width * 16) + +byte | 0 | 1 | 2 | ... | 14 | 15 | +bits 0..3 | code 0,vector 0 | code 0,vector 8 | code 0,vector 1 | ... | code 0,vector 14 | code 0,vector 15 | +bits 4..7 | code 0,vector 16 | code 0,vector 24 | code 0,vector 17 | ... | code 0,vector 30 | code 0,vector 31 | + +byte | 16 | 17 | 18 | ... | 30 | 31 | +bits 0..3 | code 1,vector 0 | code 1,vector 8 | code 1,vector 1 | ... | code 1,vector 14 | code 1,vector 15 | +bits 4..7 | code 1,vector 16 | code 1,vector 24 | code 1,vector 17 | ... | code 1,vector 30 | code 1,vector 31 | + +byte | 32 | 33 | 34 | ... | 46 | 47 | +bits 0..3 | code 2,vector 0 | code 2,vector 8 | code 2,vector 1 | ... | code 2,vector 14 | code 2,vector 15 | +bits 4..7 | code 2,vector 16 | code 2,vector 24 | code 2,vector 17 | ... | code 2,vector 30 | code 2,vector 31 | + +... + +byte | width*32-32 | width*32-31 | ... | width*32-1 | +bits 0..3 | code (width-1),vector 0 | code (width-1),vector 8 | ... | code (width-1),vector 15 | +bits 4..7 | code (width-1),vector 16 | code (width-1),vector 24 | ... 
| code (width-1),vector 31 | + +*/ + +pub fn pack(width: u32, r: [Vec; 32]) -> impl Iterator { + (0..width as usize).flat_map(move |i| { + [ + r[0][i] | (r[16][i] << 4), + r[8][i] | (r[24][i] << 4), + r[1][i] | (r[17][i] << 4), + r[9][i] | (r[25][i] << 4), + r[2][i] | (r[18][i] << 4), + r[10][i] | (r[26][i] << 4), + r[3][i] | (r[19][i] << 4), + r[11][i] | (r[27][i] << 4), + r[4][i] | (r[20][i] << 4), + r[12][i] | (r[28][i] << 4), + r[5][i] | (r[21][i] << 4), + r[13][i] | (r[29][i] << 4), + r[6][i] | (r[22][i] << 4), + r[14][i] | (r[30][i] << 4), + r[7][i] | (r[23][i] << 4), + r[15][i] | (r[31][i] << 4), + ] + .into_iter() + }) +} + +#[allow(clippy::module_inception)] +mod fast_scan { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn fast_scan_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + #[inline] + #[crate::target_cpu(enable = "v4")] + fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + unsafe { + let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); + let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); + _mm256_add_epi16(x1y0, x0y1) + } + } + + #[inline] + #[crate::target_cpu(enable = "v4")] + fn combine4x2(x0x1x2x3: __m512i, y0y1y2y3: __m512i) -> __m256i { + unsafe { + let x0x1 = _mm512_castsi512_si256(x0x1x2x3); + let x2x3 = _mm512_extracti64x4_epi64(x0x1x2x3, 1); + let y0y1 = _mm512_castsi512_si256(y0y1y2y3); + let y2y3 = _mm512_extracti64x4_epi64(y0y1y2y3, 1); + let x01y01 = combine2x2(x0x1, y0y1); + let x23y23 = combine2x2(x2x3, y2y3); + _mm256_add_epi16(x01y01, x23y23) + } + } + + let mut accu_0 = _mm512_setzero_si512(); + let mut accu_1 = _mm512_setzero_si512(); + let mut accu_2 = _mm512_setzero_si512(); + let mut accu_3 = _mm512_setzero_si512(); + + let mut i = 0_usize; + while i + 4 <= width as usize { + let c = _mm512_loadu_si512(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm512_set1_epi8(0xf); + let clo = _mm512_and_si512(c, mask); + let chi = _mm512_and_si512(_mm512_srli_epi16(c, 4), mask); + + let lut = _mm512_loadu_si512(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm512_shuffle_epi8(lut, clo); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = _mm512_shuffle_epi8(lut, chi); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + + i += 4; + } + if i + 2 <= width as usize { + let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm256_set1_epi8(0xf); + let clo = _mm256_and_si256(c, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + + let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, clo)); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, chi)); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + + i += 2; + } + if i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let lut = _mm_loadu_si128(lut.as_ptr().add(i * 
16).cast()); + let res_lo = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, clo)); + accu_0 = _mm512_add_epi16(accu_0, res_lo); + accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); + let res_hi = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, chi)); + accu_2 = _mm512_add_epi16(accu_2, res_hi); + accu_3 = _mm512_add_epi16(accu_3, _mm512_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); + + let mut result = [0_u16; 32]; + + accu_0 = _mm512_sub_epi16(accu_0, _mm512_slli_epi16(accu_1, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(0).cast(), + combine4x2(accu_0, accu_1), + ); + + accu_2 = _mm512_sub_epi16(accu_2, _mm512_slli_epi16(accu_3, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(16).cast(), + combine4x2(accu_2, accu_3), + ); + + result + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn fast_scan_v4_test() { + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_v4(width, &codes, &lut), + fast_scan_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn fast_scan_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + #[inline] + #[crate::target_cpu(enable = "v3")] + fn combine2x2(x0x1: __m256i, y0y1: __m256i) -> __m256i { + unsafe { + let x1y0 = _mm256_permute2f128_si256(x0x1, y0y1, 0x21); + let x0y1 = _mm256_blend_epi32(x0x1, y0y1, 0xf0); + _mm256_add_epi16(x1y0, x0y1) + } + } + + let mut accu_0 = _mm256_setzero_si256(); + let mut accu_1 = _mm256_setzero_si256(); + let mut accu_2 = _mm256_setzero_si256(); + let mut accu_3 = _mm256_setzero_si256(); + + let mut i = 0_usize; + while i + 2 <= width as usize { + let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm256_set1_epi8(0xf); + let clo = _mm256_and_si256(c, mask); + let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); + + let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm256_shuffle_epi8(lut, clo); + accu_0 = _mm256_add_epi16(accu_0, res_lo); + accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); + let res_hi = _mm256_shuffle_epi8(lut, chi); + accu_2 = _mm256_add_epi16(accu_2, res_hi); + accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); + + i += 2; + } + if i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, clo)); + accu_0 = _mm256_add_epi16(accu_0, res_lo); + accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); + let res_hi = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, chi)); + accu_2 = _mm256_add_epi16(accu_2, res_hi); + accu_3 = _mm256_add_epi16(accu_3, _mm256_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); + + let mut result = [0_u16; 32]; + + accu_0 = 
_mm256_sub_epi16(accu_0, _mm256_slli_epi16(accu_1, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(0).cast(), + combine2x2(accu_0, accu_1), + ); + + accu_2 = _mm256_sub_epi16(accu_2, _mm256_slli_epi16(accu_3, 8)); + _mm256_storeu_si256( + result.as_mut_ptr().add(16).cast(), + combine2x2(accu_2, accu_3), + ); + + result + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn fast_scan_v3_test() { + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_v3(width, &codes, &lut), + fast_scan_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn fast_scan_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::x86_64::*; + + let mut accu_0 = _mm_setzero_si128(); + let mut accu_1 = _mm_setzero_si128(); + let mut accu_2 = _mm_setzero_si128(); + let mut accu_3 = _mm_setzero_si128(); + + let mut i = 0_usize; + while i < width as usize { + let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + + let mask = _mm_set1_epi8(0xf); + let clo = _mm_and_si128(c, mask); + let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); + + let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let res_lo = _mm_shuffle_epi8(lut, clo); + accu_0 = _mm_add_epi16(accu_0, res_lo); + accu_1 = _mm_add_epi16(accu_1, _mm_srli_epi16(res_lo, 8)); + let res_hi = _mm_shuffle_epi8(lut, chi); + accu_2 = _mm_add_epi16(accu_2, res_hi); + accu_3 = _mm_add_epi16(accu_3, _mm_srli_epi16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); + + let mut result = [0_u16; 32]; + + accu_0 = _mm_sub_epi16(accu_0, _mm_slli_epi16(accu_1, 8)); + _mm_storeu_si128(result.as_mut_ptr().add(0).cast(), accu_0); + _mm_storeu_si128(result.as_mut_ptr().add(8).cast(), accu_1); + + accu_2 = _mm_sub_epi16(accu_2, _mm_slli_epi16(accu_3, 8)); + _mm_storeu_si128(result.as_mut_ptr().add(16).cast(), accu_2); + _mm_storeu_si128(result.as_mut_ptr().add(24).cast(), accu_3); + + result + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn fast_scan_v2_test() { + if !crate::is_cpu_detected!("v2") { + println!("test {} ... 
skipped (v2)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_v2(width, &codes, &lut), + fast_scan_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn fast_scan_v8_3a(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + // bounds checking is not enforced by compiler, so check it manually + assert_eq!(codes.len(), width as usize * 16); + assert_eq!(lut.len(), width as usize * 16); + + unsafe { + use std::arch::aarch64::*; + + let mut accu_0 = vdupq_n_u16(0); + let mut accu_1 = vdupq_n_u16(0); + let mut accu_2 = vdupq_n_u16(0); + let mut accu_3 = vdupq_n_u16(0); + + let mut i = 0_usize; + while i < width as usize { + let c = vld1q_u8(codes.as_ptr().add(i * 16).cast()); + + let mask = vdupq_n_u8(0xf); + let clo = vandq_u8(c, mask); + let chi = vandq_u8(vshrq_n_u8(c, 4), mask); + + let lut = vld1q_u8(lut.as_ptr().add(i * 16).cast()); + let res_lo = vreinterpretq_u16_u8(vqtbl1q_u8(lut, clo)); + accu_0 = vaddq_u16(accu_0, res_lo); + accu_1 = vaddq_u16(accu_1, vshrq_n_u16(res_lo, 8)); + let res_hi = vreinterpretq_u16_u8(vqtbl1q_u8(lut, chi)); + accu_2 = vaddq_u16(accu_2, res_hi); + accu_3 = vaddq_u16(accu_3, vshrq_n_u16(res_hi, 8)); + + i += 1; + } + debug_assert_eq!(i, width as usize); + + let mut result = [0_u16; 32]; + + accu_0 = vsubq_u16(accu_0, vshlq_n_u16(accu_1, 8)); + vst1q_u16(result.as_mut_ptr().add(0).cast(), accu_0); + vst1q_u16(result.as_mut_ptr().add(8).cast(), accu_1); + + accu_2 = vsubq_u16(accu_2, vshlq_n_u16(accu_3, 8)); + vst1q_u16(result.as_mut_ptr().add(16).cast(), accu_2); + vst1q_u16(result.as_mut_ptr().add(24).cast(), accu_3); + + result + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn fast_scan_v8_3a_test() { + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + for width in 90..110 { + let codes = (0..16 * width).map(|_| rand::random()).collect::>(); + let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + unsafe { + assert_eq!( + fast_scan_v8_3a(width, &codes, &lut), + fast_scan_fallback(width, &codes, &lut) + ); + } + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] + pub fn fast_scan(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + let width = width as usize; + + assert_eq!(codes.len(), width * 16); + assert_eq!(lut.len(), width * 16); + + use std::array::from_fn; + use std::ops::BitAnd; + + fn load(slice: &[T]) -> [T; N] { + from_fn(|i| slice[i]) + } + fn unary(op: impl Fn(T) -> U, a: [T; N]) -> [U; N] { + from_fn(|i| op(a[i])) + } + fn binary(op: impl Fn(T, T) -> T, a: [T; N], b: [T; N]) -> [T; N] { + from_fn(|i| op(a[i], b[i])) + } + fn shuffle(a: [T; N], b: [u8; N]) -> [T; N] { + from_fn(|i| a[b[i] as usize]) + } + fn cast(x: [u8; 16]) -> [u16; 8] { + from_fn(|i| u16::from_le_bytes([x[i << 1 | 0], x[i << 1 | 1]])) + } + fn setr(x: [[T; 8]; 4]) -> [T; 32] { + from_fn(|i| x[i >> 3][i & 7]) + } + + let mut a_0 = [0u16; 8]; + let mut a_1 = [0u16; 8]; + let mut a_2 = [0u16; 8]; + let mut a_3 = [0u16; 8]; + + for i in 0..width { + let c = load(&codes[16 * i..]); + + let mask = [0xfu8; 16]; + let clo = binary(u8::bitand, c, mask); + let chi = binary(u8::bitand, unary(|x| x >> 4, c), mask); + + let lut = load(&lut[16 * i..]); + let res_lo = cast(shuffle(lut, clo)); + a_0 = binary(u16::wrapping_add, a_0, res_lo); + a_1 = binary(u16::wrapping_add, a_1, unary(|x| x >> 8, res_lo)); + let res_hi = cast(shuffle(lut, chi)); + a_2 = binary(u16::wrapping_add, a_2, res_hi); + a_3 = binary(u16::wrapping_add, a_3, unary(|x| x >> 8, res_hi)); + } + + a_0 = binary(u16::wrapping_sub, a_0, unary(|x| x.wrapping_shl(8), a_1)); + a_2 = binary(u16::wrapping_sub, a_2, unary(|x| x.wrapping_shl(8), a_3)); + + setr([a_0, a_1, a_2, a_3]) + } +} + +#[inline(always)] +pub fn fast_scan(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + fast_scan::fast_scan(width, codes, lut) +} diff --git a/crates/simd/src/lib.rs b/crates/simd/src/lib.rs new file mode 100644 index 0000000..aec5c52 --- /dev/null +++ b/crates/simd/src/lib.rs @@ -0,0 +1,107 @@ +#![feature(target_feature_11)] +#![feature(avx512_target_feature)] +#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] +#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))] + +mod aligned; +mod emulate; +mod f16; +mod f32; + +pub mod bit; +pub mod fast_scan; +pub mod packed_u4; +pub mod quantize; +pub mod u8; + +pub trait Floating: + Copy + + Send + + Sync + + std::fmt::Debug + + serde::Serialize + + for<'a> serde::Deserialize<'a> + + Default + + 'static + + PartialEq + + PartialOrd +{ + fn zero() -> Self; + fn infinity() -> Self; + fn mask(self, m: bool) -> Self; + + fn scalar_neg(this: Self) -> Self; + fn scalar_add(lhs: Self, rhs: Self) -> Self; + fn scalar_sub(lhs: Self, rhs: Self) -> Self; + fn scalar_mul(lhs: Self, rhs: Self) -> Self; + + fn reduce_or_of_is_zero_x(this: &[Self]) -> bool; + fn reduce_sum_of_x(this: &[Self]) -> f32; + fn reduce_sum_of_abs_x(this: &[Self]) -> f32; + fn reduce_sum_of_x2(this: &[Self]) -> f32; + fn reduce_min_max_of_x(this: &[Self]) -> (f32, f32); + fn reduce_sum_of_xy(lhs: &[Self], rhs: &[Self]) -> f32; + fn reduce_sum_of_d2(lhs: &[Self], rhs: &[Self]) -> f32; + fn reduce_sum_of_xy_sparse(lidx: &[u32], lval: &[Self], ridx: 
&[u32], rval: &[Self]) -> f32; + fn reduce_sum_of_d2_sparse(lidx: &[u32], lval: &[Self], ridx: &[u32], rval: &[Self]) -> f32; + + fn vector_from_f32(this: &[f32]) -> Vec; + fn vector_to_f32(this: &[Self]) -> Vec; + fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]>; + fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec; + fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]); + fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec; + fn vector_mul(lhs: &[Self], rhs: &[Self]) -> Vec; + fn vector_mul_scalar(lhs: &[Self], rhs: f32) -> Vec; + fn vector_mul_scalar_inplace(lhs: &mut [Self], rhs: f32); +} + +mod internal { + #[cfg(target_arch = "x86_64")] + simd_macros::define_is_cpu_detected!("x86_64"); + + #[cfg(target_arch = "aarch64")] + simd_macros::define_is_cpu_detected!("aarch64"); + + #[cfg(target_arch = "riscv64")] + simd_macros::define_is_cpu_detected!("riscv64"); + + #[cfg(target_arch = "x86_64")] + #[allow(unused_imports)] + pub use is_x86_64_cpu_detected; + + #[cfg(target_arch = "aarch64")] + #[allow(unused_imports)] + pub use is_aarch64_cpu_detected; + + #[cfg(target_arch = "riscv64")] + #[allow(unused_imports)] + pub use is_riscv64_cpu_detected; +} + +pub use simd_macros::multiversion; +pub use simd_macros::target_cpu; + +#[cfg(target_arch = "x86_64")] +#[allow(unused_imports)] +pub use std::arch::is_x86_feature_detected as is_feature_detected; + +#[cfg(target_arch = "aarch64")] +#[allow(unused_imports)] +pub use std::arch::is_aarch64_feature_detected as is_feature_detected; + +#[cfg(target_arch = "riscv64")] +#[allow(unused_imports)] +pub use std::arch::is_riscv_feature_detected as is_feature_detected; + +#[cfg(target_arch = "x86_64")] +#[allow(unused_imports)] +pub use internal::is_x86_64_cpu_detected as is_cpu_detected; + +#[cfg(target_arch = "aarch64")] +#[allow(unused_imports)] +pub use internal::is_aarch64_cpu_detected as is_cpu_detected; + +#[cfg(target_arch = "riscv64")] +#[allow(unused_imports)] +pub use internal::is_riscv64_cpu_detected as is_cpu_detected; diff --git a/crates/simd/src/packed_u4.rs b/crates/simd/src/packed_u4.rs new file mode 100644 index 0000000..d68333c --- /dev/null +++ b/crates/simd/src/packed_u4.rs @@ -0,0 +1,18 @@ +pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { + reduce_sum_of_xy::reduce_sum_of_xy(s, t) +} + +mod reduce_sum_of_xy { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { + assert_eq!(s.len(), t.len()); + let n = s.len(); + let mut result = 0; + for i in 0..n { + let (s, t) = (s[i], t[i]); + result += ((s & 15) as u32) * ((t & 15) as u32); + result += ((s >> 4) as u32) * ((t >> 4) as u32); + } + result + } +} diff --git a/crates/simd/src/quantize.rs b/crates/simd/src/quantize.rs new file mode 100644 index 0000000..baaf594 --- /dev/null +++ b/crates/simd/src/quantize.rs @@ -0,0 +1,291 @@ +mod mul_add_round { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn mul_add_round_v4(this: &[f32], k: f32, b: f32) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + unsafe { + use std::arch::x86_64::*; + let lk = _mm512_set1_ps(k); + let lb = _mm512_set1_ps(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 16 { + let x = _mm512_loadu_ps(a); + let v = + _mm512_fmadd_round_ps(x, lk, lb, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + let v = _mm512_cvtps_epi32(v); + let vfl = _mm512_cvtepi32_epi8(v); + _mm_storeu_si128(r.cast(), vfl); + n -= 16; + a = a.add(16); + r = 
r.add(16); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm512_maskz_loadu_ps(mask, a); + let v = + _mm512_fmadd_round_ps(x, lk, lb, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + let v = _mm512_cvtps_epi32(v); + let vfl = _mm512_cvtepi32_epi8(v); + _mm_mask_storeu_epi8(r.cast(), mask, vfl); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn mul_add_round_v4_test() { + if !crate::is_cpu_detected!("v4") { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v4(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn mul_add_round_v3(this: &[f32], k: f32, b: f32) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + unsafe { + use std::arch::x86_64::*; + let cons = _mm256_setr_epi8( + 0, 4, 8, 12, -1, -1, -1, -1, // 0..8 + -1, -1, -1, -1, -1, -1, -1, -1, // 8..15 + 0, 4, 8, 12, -1, -1, -1, -1, // 16..24 + -1, -1, -1, -1, -1, -1, -1, -1, // 24..32 + ); + let lk = _mm256_set1_ps(k); + let lb = _mm256_set1_ps(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 8 { + let x = _mm256_loadu_ps(a); + let v = _mm256_fmadd_ps(x, lk, lb); + let v = _mm256_cvtps_epi32(_mm256_round_ps(v, 0x00)); + let vs = _mm256_shuffle_epi8(v, cons); + let vlo = _mm256_extract_epi32::<0>(vs) as u32; + let vhi = _mm256_extract_epi32::<4>(vs) as u32; + let vfl = vlo as u64 | ((vhi as u64) << 32); + r.cast::().write_unaligned(vfl); + n -= 8; + a = a.add(8); + r = r.add(8); + } + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let v = x.mul_add(k, b).round_ties_even() as u8; + r.write(v); + n -= 1; + a = a.add(1); + r = r.add(1); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn mul_add_round_v3_test() { + if !crate::is_cpu_detected!("v3") { + println!("test {} ... 
skipped (v3)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v3(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + #[target_feature(enable = "fma")] + fn mul_add_round_v2_fma(this: &[f32], k: f32, b: f32) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + unsafe { + use std::arch::x86_64::*; + let cons = _mm_setr_epi8( + 0, 4, 8, 12, -1, -1, -1, -1, // 0..8 + -1, -1, -1, -1, -1, -1, -1, -1, // 8..15 + ); + let lk = _mm_set1_ps(k); + let lb = _mm_set1_ps(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 4 { + let x = _mm_loadu_ps(a); + let v = _mm_fmadd_ps(x, lk, lb); + let v = _mm_cvtps_epi32(_mm_round_ps(v, 0x00)); + let vs = _mm_shuffle_epi8(v, cons); + let vfl = _mm_extract_epi32::<0>(vs) as u32; + r.cast::().write_unaligned(vfl); + n -= 4; + a = a.add(4); + r = r.add(4); + } + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let v = x.mul_add(k, b).round_ties_even() as u8; + r.write(v); + n -= 1; + a = a.add(1); + r = r.add(1); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn mul_add_round_v2_fma_test() { + if !crate::is_cpu_detected!("v2") || !crate::is_feature_detected!("fma") { + println!("test {} ... skipped (v2:fma)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v2_fma(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn mul_add_round_v8_3a(this: &[f32], k: f32, b: f32) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + unsafe { + use std::arch::aarch64::*; + let cons = vld1q_u8( + [ + 0, 4, 8, 12, 0xff, 0xff, 0xff, 0xff, // 0..8 + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // 8..15 + ] + .as_ptr(), + ); + let lk = vdupq_n_f32(k); + let lb = vdupq_n_f32(b); + let mut n = n; + let mut a = this.as_ptr(); + let mut r = r.as_mut_ptr(); + while n >= 4 { + let x = vld1q_f32(a); + let v = vfmaq_f32(lb, x, lk); + let v = vcvtnq_u32_f32(v); + let vs = vqtbl1q_u8(vreinterpretq_u8_u32(v), cons); + let vfl = vgetq_lane_u32::<0>(vreinterpretq_u32_u8(vs)); + r.cast::().write_unaligned(vfl); + n -= 4; + a = a.add(4); + r = r.add(4); + } + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + let v = x.mul_add(k, b).round_ties_even() as u8; + r.write(v); + n -= 1; + a = a.add(1); + r = r.add(1); + } + } + unsafe { + r.set_len(n); + } + r + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn mul_add_round_v8_3a_test() { + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... 
skipped (v8.3a)", module_path!()); + return; + } + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4010; + let x = (0..n).map(|_| rand::random::<_>()).collect::>(); + for z in 3990..4010 { + let x = &x[..z]; + let k = 20.0; + let b = 20.0; + let specialized = unsafe { mul_add_round_v8_3a(x, k, b) }; + let fallback = mul_add_round_fallback(x, k, b); + assert_eq!(specialized, fallback); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2:fma", @"v8.3a")] + pub fn mul_add_round(this: &[f32], k: f32, b: f32) -> Vec { + let n = this.len(); + let mut r = Vec::::with_capacity(n); + for i in 0..n { + let x = this[i]; + let v = x.mul_add(k, b).round_ties_even() as u8; + unsafe { + r.as_mut_ptr().add(i).write(v); + } + } + unsafe { + r.set_len(n); + } + r + } +} + +#[inline(always)] +pub fn quantize(lut: &[f32], n: f32) -> (f32, f32, Vec) { + use crate::Floating; + let (min, max) = f32::reduce_min_max_of_x(lut); + let k = 0.0f32.max((max - min) / n); + let b = min; + (k, b, mul_add_round::mul_add_round(lut, 1.0 / k, -b / k)) +} diff --git a/crates/simd/src/u8.rs b/crates/simd/src/u8.rs new file mode 100644 index 0000000..bc84533 --- /dev/null +++ b/crates/simd/src/u8.rs @@ -0,0 +1,343 @@ +mod reduce_sum_of_xy { + #[crate::multiversion("v4", "v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { + assert_eq!(s.len(), t.len()); + let n = s.len(); + let mut result = 0; + for i in 0..n { + result += (s[i] as u32) * (t[i] as u32); + } + result + } +} + +#[inline(always)] +pub fn reduce_sum_of_xy(s: &[u8], t: &[u8]) -> u32 { + reduce_sum_of_xy::reduce_sum_of_xy(s, t) +} + +mod reduce_sum_of_x_as_u16 { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_x_as_u16_v4(this: &[u8]) -> u16 { + use crate::emulate::emulate_mm512_reduce_add_epi16; + unsafe { + use std::arch::x86_64::*; + let us = _mm512_set1_epi16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm512_setzero_si512(); + while n >= 32 { + let x = _mm256_loadu_si256(a.cast()); + a = a.add(32); + n -= 32; + sum = _mm512_add_epi16(_mm512_and_si512(us, _mm512_cvtepi8_epi16(x)), sum); + } + if n > 0 { + let mask = _bzhi_u32(0xffffffff, n as u32); + let x = _mm256_maskz_loadu_epi8(mask, a.cast()); + sum = _mm512_add_epi16(_mm512_and_si512(us, _mm512_cvtepi8_epi16(x)), sum); + } + emulate_mm512_reduce_add_epi16(sum) as u16 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_as_u16_v4_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v4(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_x_as_u16_v3(this: &[u8]) -> u16 { + use crate::emulate::emulate_mm256_reduce_add_epi16; + unsafe { + use std::arch::x86_64::*; + let us = _mm256_set1_epi16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm256_setzero_si256(); + while n >= 16 { + let x = _mm_loadu_si128(a.cast()); + a = a.add(16); + n -= 16; + sum = _mm256_add_epi16(_mm256_and_si256(us, _mm256_cvtepi8_epi16(x)), sum); + } + let mut sum = emulate_mm256_reduce_add_epi16(sum) as u16; + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u16; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_as_u16_v3_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v3(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v2")] + fn reduce_sum_of_x_as_u16_v2(this: &[u8]) -> u16 { + use crate::emulate::emulate_mm_reduce_add_epi16; + unsafe { + use std::arch::x86_64::*; + let us = _mm_set1_epi16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm_setzero_si128(); + while n >= 8 { + let x = _mm_loadu_si64(a.cast()); + a = a.add(8); + n -= 8; + sum = _mm_add_epi16(_mm_and_si128(us, _mm_cvtepi8_epi16(x)), sum); + } + let mut sum = emulate_mm_reduce_add_epi16(sum) as u16; + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u16; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_as_u16_v2_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v2") { + println!("test {} ... 
skipped (v2)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v2(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "aarch64")] + #[crate::target_cpu(enable = "v8.3a")] + fn reduce_sum_of_x_as_u16_v8_3a(this: &[u8]) -> u16 { + unsafe { + use std::arch::aarch64::*; + let us = vdupq_n_u16(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = vdupq_n_u16(0); + while n >= 8 { + let x = vld1_u8(a); + a = a.add(8); + n -= 8; + sum = vaddq_u16(vandq_u16(us, vmovl_u8(x)), sum); + } + let mut sum = vaddvq_u16(sum); + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u16; + } + sum + } + } + + #[cfg(all(target_arch = "aarch64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_as_u16_v8_3a_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v8.3a") { + println!("test {} ... skipped (v8.3a)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_as_u16_v8_3a(this) }; + let fallback = reduce_sum_of_x_as_u16_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] + pub fn reduce_sum_of_x_as_u16(this: &[u8]) -> u16 { + let n = this.len(); + let mut sum = 0; + for i in 0..n { + sum += this[i] as u16; + } + sum + } +} + +#[inline(always)] +pub fn reduce_sum_of_x_as_u16(vector: &[u8]) -> u16 { + reduce_sum_of_x_as_u16::reduce_sum_of_x_as_u16(vector) +} + +mod reduce_sum_of_x { + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v4")] + fn reduce_sum_of_x_v4(this: &[u8]) -> u32 { + unsafe { + use std::arch::x86_64::*; + let us = _mm512_set1_epi32(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm512_setzero_si512(); + while n >= 16 { + let x = _mm_loadu_epi8(a.cast()); + a = a.add(16); + n -= 16; + sum = _mm512_add_epi32(_mm512_and_si512(us, _mm512_cvtepi8_epi32(x)), sum); + } + if n > 0 { + let mask = _bzhi_u32(0xffff, n as u32) as u16; + let x = _mm_maskz_loadu_epi8(mask, a.cast()); + sum = _mm512_add_epi32(_mm512_and_si512(us, _mm512_cvtepi8_epi32(x)), sum); + } + _mm512_reduce_add_epi32(sum) as u32 + } + } + + #[cfg(all(target_arch = "x86_64", test, not(miri)))] + #[test] + fn reduce_sum_of_x_v4_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v4") { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v4(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[inline] + #[cfg(target_arch = "x86_64")] + #[crate::target_cpu(enable = "v3")] + fn reduce_sum_of_x_v3(this: &[u8]) -> u32 { + use crate::emulate::emulate_mm256_reduce_add_epi32; + unsafe { + use std::arch::x86_64::*; + let us = _mm256_set1_epi32(255); + let mut n = this.len(); + let mut a = this.as_ptr(); + let mut sum = _mm256_setzero_si256(); + while n >= 8 { + let x = _mm_loadl_epi64(a.cast()); + a = a.add(8); + n -= 8; + sum = _mm256_add_epi32(_mm256_and_si256(us, _mm256_cvtepi8_epi32(x)), sum); + } + let mut sum = emulate_mm256_reduce_add_epi32(sum) as u32; + // this hint is used to disable loop unrolling + while std::hint::black_box(n) > 0 { + let x = a.read(); + a = a.add(1); + n -= 1; + sum += x as u32; + } + sum + } + } + + #[cfg(all(target_arch = "x86_64", test))] + #[test] + fn reduce_sum_of_x_as_u16_v3_test() { + use rand::Rng; + if !crate::is_cpu_detected!("v3") { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { + let n = 4016; + let this = (0..n).map(|_| rng.gen_range(0..16)).collect::>(); + for z in 3984..4016 { + let this = &this[..z]; + let specialized = unsafe { reduce_sum_of_x_v3(this) }; + let fallback = reduce_sum_of_x_fallback(this); + assert_eq!(specialized, fallback); + } + } + } + + #[crate::multiversion(@"v4", @"v3", "v2", "v8.3a:sve", "v8.3a")] + pub fn reduce_sum_of_x(this: &[u8]) -> u32 { + let n = this.len(); + let mut sum = 0; + for i in 0..n { + sum += this[i] as u32; + } + sum + } +} + +#[inline(always)] +pub fn reduce_sum_of_x(vector: &[u8]) -> u32 { + reduce_sum_of_x::reduce_sum_of_x(vector) +} diff --git a/crates/simd_macros/Cargo.toml b/crates/simd_macros/Cargo.toml new file mode 100644 index 0000000..799db57 --- /dev/null +++ b/crates/simd_macros/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "simd_macros" +version.workspace = true +edition.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = { version = "1.0.79", features = ["proc-macro"] } +quote = "1.0.35" +syn = { version = "2.0.53", default-features = false, features = [ + "clone-impls", + "full", + "parsing", + "printing", + "proc-macro", +] } diff --git a/crates/simd_macros/src/lib.rs b/crates/simd_macros/src/lib.rs new file mode 100644 index 0000000..fe45564 --- /dev/null +++ b/crates/simd_macros/src/lib.rs @@ -0,0 +1,206 @@ +mod target; + +struct MultiversionVersion { + target: String, + import: bool, +} + +impl syn::parse::Parse for MultiversionVersion { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead1 = input.lookahead1(); + if lookahead1.peek(syn::Token![@]) { + let _: syn::Token![@] = input.parse()?; + let target: syn::LitStr = input.parse()?; + Ok(Self { + target: target.value(), + import: true, + }) + } else { + let target: syn::LitStr = input.parse()?; + Ok(Self { + target: target.value(), + import: false, + }) + } + } +} + +struct Multiversion { + versions: syn::punctuated::Punctuated, +} + +impl syn::parse::Parse for Multiversion { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + 
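+        // Note on the version syntax (an observation from this macro's code,
+        // not documented in the original patch): a version written with a
+        // leading `@`, such as `@"v4"`, is an "import": the enclosing module
+        // is expected to hand-write the specialized `<name>_v4` function (as
+        // the SIMD kernels above do). A bare string such as "v2" instead asks
+        // the macro to generate that variant by recompiling the fallback body
+        // with the corresponding target features enabled.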
Ok(Multiversion { + versions: syn::punctuated::Punctuated::parse_terminated(input)?, + }) + } +} + +#[proc_macro_attribute] +pub fn multiversion( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as Multiversion); + let item_fn = syn::parse::(item).expect("not a function item"); + let syn::ItemFn { + attrs, + vis, + sig, + block, + } = item_fn; + let name = sig.ident.to_string(); + if sig.constness.is_some() { + panic!("const functions are not supported"); + } + if sig.asyncness.is_some() { + panic!("async functions are not supported"); + } + let generics_params = sig.generics.params.clone(); + for generic_param in generics_params.iter() { + if !matches!(generic_param, syn::GenericParam::Lifetime(_)) { + panic!("generic parameters are not supported"); + } + } + let generics_where = sig.generics.where_clause.clone(); + let inputs = sig.inputs.clone(); + let arguments = { + let mut list = vec![]; + for x in sig.inputs.iter() { + if let syn::FnArg::Typed(y) = x { + if let syn::Pat::Ident(ident) = *y.pat.clone() { + list.push(ident); + } else { + panic!("patterns on parameters are not supported") + } + } else { + panic!("receiver parameters are not supported") + } + } + list + }; + if sig.variadic.is_some() { + panic!("variadic parameters are not supported"); + } + let output = sig.output.clone(); + let mut versions = quote::quote! {}; + let mut branches = quote::quote! {}; + for version in attr.versions { + let target = version.target.clone(); + let name = syn::Ident::new( + &format!("{name}_{}", target.replace(":", "_").replace(".", "_")), + proc_macro2::Span::mixed_site(), + ); + let s = target.split(":").collect::>(); + let target_cpu = target::TARGET_CPUS + .iter() + .find(|target_cpu| target_cpu.target_cpu == s[0]) + .expect("unknown target_cpu"); + let additional_target_features = s[1..].to_vec(); + let target_arch = target_cpu.target_arch; + let target_cpu = target_cpu.target_cpu; + if !version.import { + versions.extend(quote::quote! { + #[inline] + #[cfg(any(target_arch = #target_arch))] + #[crate::target_cpu(enable = #target_cpu)] + #(#[target_feature(enable = #additional_target_features)])* + fn #name < #generics_params > (#inputs) #output #generics_where { #block } + }); + } + branches.extend(quote::quote! { + #[cfg(target_arch = #target_arch)] + if crate::is_cpu_detected!(#target_cpu) #(&& crate::is_feature_detected!(#additional_target_features))* { + let _multiversion_internal: unsafe fn(#inputs) #output = #name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + return unsafe { _multiversion_internal(#(#arguments,)*) }; + } + }); + } + let fallback_name = + syn::Ident::new(&format!("{name}_fallback"), proc_macro2::Span::mixed_site()); + quote::quote! 
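+    // What the generated wrapper below does (summary of the emitted code, not
+    // a comment in the original patch): on the first call it runs CPU and
+    // feature detection, picks the best available variant, and stores its
+    // address in a static `AtomicPtr` local to the wrapper; later calls load
+    // that pointer, transmute it back to `unsafe fn(...)`, and call it
+    // directly, so detection runs only once per wrapper function.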
{ + #versions + fn #fallback_name < #generics_params > (#inputs) #output #generics_where { #block } + #[inline(always)] + #(#attrs)* #vis #sig { + static CACHE: core::sync::atomic::AtomicPtr<()> = core::sync::atomic::AtomicPtr::new(core::ptr::null_mut()); + let cache = CACHE.load(core::sync::atomic::Ordering::Relaxed); + if !cache.is_null() { + let f = unsafe { core::mem::transmute::<*mut (), unsafe fn(#inputs) #output>(cache as _) }; + return unsafe { f(#(#arguments,)*) }; + } + #branches + let _multiversion_internal: unsafe fn(#inputs) #output = #fallback_name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + unsafe { _multiversion_internal(#(#arguments,)*) } + } + } + .into() +} + +struct TargetCpu { + enable: String, +} + +impl syn::parse::Parse for TargetCpu { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let _: syn::Ident = input.parse()?; + let _: syn::Token![=] = input.parse()?; + let enable: syn::LitStr = input.parse()?; + Ok(Self { + enable: enable.value(), + }) + } +} + +#[proc_macro_attribute] +pub fn target_cpu( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as TargetCpu); + let mut result = quote::quote! {}; + for s in attr.enable.split(',') { + let target_cpu = target::TARGET_CPUS + .iter() + .find(|target_cpu| target_cpu.target_cpu == s) + .expect("unknown target_cpu"); + let target_features = target_cpu.target_features; + result.extend(quote::quote!( + #(#[target_feature(enable = #target_features)])* + )); + } + result.extend(proc_macro2::TokenStream::from(item)); + result.into() +} + +#[proc_macro] +pub fn define_is_cpu_detected(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let target_arch = syn::parse_macro_input!(input as syn::LitStr).value(); + let mut arms = quote::quote! {}; + for target_cpu in target::TARGET_CPUS { + if target_cpu.target_arch != target_arch { + continue; + } + let target_features = target_cpu.target_features; + let target_cpu = target_cpu.target_cpu; + arms.extend(quote::quote! { + (#target_cpu) => { + true #(&& $crate::is_feature_detected!(#target_features))* + }; + }); + } + let name = syn::Ident::new( + &format!("is_{target_arch}_cpu_detected"), + proc_macro2::Span::mixed_site(), + ); + quote::quote! { + #[macro_export] + macro_rules! 
#name { + #arms + } + } + .into() +} diff --git a/crates/simd_macros/src/target.rs b/crates/simd_macros/src/target.rs new file mode 100644 index 0000000..c258401 --- /dev/null +++ b/crates/simd_macros/src/target.rs @@ -0,0 +1,83 @@ +pub struct TargetCpu { + pub target_cpu: &'static str, + pub target_arch: &'static str, + pub target_features: &'static [&'static str], +} + +pub const TARGET_CPUS: &[TargetCpu] = &[ + TargetCpu { + target_cpu: "v4", + target_arch: "x86_64", + target_features: &[ + "avx", + "avx2", + "avx512bw", + "avx512cd", + "avx512dq", + "avx512f", + "avx512vl", + "bmi1", + "bmi2", + "cmpxchg16b", + "f16c", + "fma", + "fxsr", + "lzcnt", + "movbe", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + "xsave", + ], + }, + TargetCpu { + target_cpu: "v3", + target_arch: "x86_64", + target_features: &[ + "avx", + "avx2", + "bmi1", + "bmi2", + "cmpxchg16b", + "f16c", + "fma", + "fxsr", + "lzcnt", + "movbe", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + "xsave", + ], + }, + TargetCpu { + target_cpu: "v2", + target_arch: "x86_64", + target_features: &[ + "cmpxchg16b", + "fxsr", + "popcnt", + "sse", + "sse2", + "sse3", + "sse4.1", + "sse4.2", + "ssse3", + ], + }, + TargetCpu { + target_cpu: "v8.3a", + target_arch: "aarch64", + target_features: &[ + "crc", "dpb", "fcma", "jsconv", "lse", "neon", "paca", "pacg", "rcpc", "rdm", + ], + }, +]; diff --git a/crates/vector/Cargo.toml b/crates/vector/Cargo.toml new file mode 100644 index 0000000..b910d36 --- /dev/null +++ b/crates/vector/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "vector" +version.workspace = true +edition.workspace = true + +[dependencies] +distance = { path = "../distance" } +half.workspace = true +serde.workspace = true +simd = { path = "../simd" } + +[lints] +workspace = true diff --git a/crates/vector/src/bvect.rs b/crates/vector/src/bvect.rs new file mode 100644 index 0000000..fe80164 --- /dev/null +++ b/crates/vector/src/bvect.rs @@ -0,0 +1,274 @@ +use crate::{VectorBorrowed, VectorOwned}; +use distance::Distance; +use serde::{Deserialize, Serialize}; +use std::ops::{Bound, RangeBounds}; + +pub const BVECTOR_WIDTH: u32 = u64::BITS; + +// When using binary vector, please ensure that the padding bits are always zero. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BVectOwned { + dims: u32, + data: Vec, +} + +impl BVectOwned { + #[inline(always)] + pub fn new(dims: u32, data: Vec) -> Self { + Self::new_checked(dims, data).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(dims: u32, data: Vec) -> Option { + if !(1..=65535).contains(&dims) { + return None; + } + if data.len() != dims.div_ceil(BVECTOR_WIDTH) as usize { + return None; + } + if dims % BVECTOR_WIDTH != 0 && data[data.len() - 1] >> (dims % BVECTOR_WIDTH) != 0 { + return None; + } + unsafe { Some(Self::new_unchecked(dims, data)) } + } + + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `data` must be of the correct length. + /// * The padding bits must be zero. 
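+    // Illustrative sketch (not part of the original patch): building the bit
+    // payload from a slice of booleans so that the padding bits stay zero,
+    // which is what the safety contract above requires.
+    //
+    //     let dims = bits.len() as u32; // assumed to be in 1..=65535
+    //     let mut data = vec![0u64; dims.div_ceil(BVECTOR_WIDTH) as usize];
+    //     for (i, &bit) in bits.iter().enumerate() {
+    //         if bit {
+    //             data[i / BVECTOR_WIDTH as usize] |= 1u64 << (i as u32 % BVECTOR_WIDTH);
+    //         }
+    //     }
+    //     let vector = BVectOwned::new(dims, data); // checked constructor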
+ #[inline(always)] + pub unsafe fn new_unchecked(dims: u32, data: Vec) -> Self { + Self { dims, data } + } +} + +impl VectorOwned for BVectOwned { + type Borrowed<'a> = BVectBorrowed<'a>; + + #[inline(always)] + fn as_borrowed(&self) -> BVectBorrowed<'_> { + BVectBorrowed { + dims: self.dims, + data: &self.data, + } + } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(dims, vec![0; dims.div_ceil(BVECTOR_WIDTH) as usize]) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct BVectBorrowed<'a> { + dims: u32, + data: &'a [u64], +} + +impl<'a> BVectBorrowed<'a> { + #[inline(always)] + pub fn new(dims: u32, data: &'a [u64]) -> Self { + Self::new_checked(dims, data).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(dims: u32, data: &'a [u64]) -> Option { + if !(1..=65535).contains(&dims) { + return None; + } + if data.len() != dims.div_ceil(BVECTOR_WIDTH) as usize { + return None; + } + if dims % BVECTOR_WIDTH != 0 && data[data.len() - 1] >> (dims % BVECTOR_WIDTH) != 0 { + return None; + } + unsafe { Some(Self::new_unchecked(dims, data)) } + } + + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `data` must be of the correct length. + /// * The padding bits must be zero. + #[inline(always)] + pub unsafe fn new_unchecked(dims: u32, data: &'a [u64]) -> Self { + Self { dims, data } + } + + #[inline(always)] + pub fn data(&self) -> &'a [u64] { + self.data + } + + #[inline(always)] + pub fn get(&self, index: u32) -> bool { + assert!(index < self.dims); + self.data[(index / BVECTOR_WIDTH) as usize] & (1 << (index % BVECTOR_WIDTH)) != 0 + } + + #[inline(always)] + pub fn iter(self) -> impl Iterator + 'a { + let mut index = 0_u32; + std::iter::from_fn(move || { + if index < self.dims { + let result = self.data[(index / BVECTOR_WIDTH) as usize] + & (1 << (index % BVECTOR_WIDTH)) + != 0; + index += 1; + Some(result) + } else { + None + } + }) + } +} + +impl VectorBorrowed for BVectBorrowed<'_> { + type Owned = BVectOwned; + + #[inline(always)] + fn dims(&self) -> u32 { + self.dims + } + + fn own(&self) -> BVectOwned { + BVectOwned { + dims: self.dims, + data: self.data.to_vec(), + } + } + + #[inline(always)] + fn norm(&self) -> f32 { + (simd::bit::sum_of_x(self.data) as f32).sqrt() + } + + #[inline(always)] + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-(simd::bit::sum_of_and(self.data, rhs.data) as f32)) + } + + #[inline(always)] + fn operator_l2(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn operator_cos(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn operator_hamming(self, rhs: Self) -> Distance { + Distance::from(simd::bit::sum_of_xor(self.data, rhs.data) as f32) + } + + #[inline(always)] + fn operator_jaccard(self, rhs: Self) -> Distance { + let (and, or) = simd::bit::sum_of_and_or(self.data, rhs.data); + Distance::from(1.0 - (and as f32 / or as f32)) + } + + #[inline(always)] + fn function_normalize(&self) -> BVectOwned { + unimplemented!() + } + + fn operator_add(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_sub(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_mul(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_and(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let data = simd::bit::vector_and(self.data, self.data); + BVectOwned::new(self.dims, data) + } + + fn operator_or(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let data = 
simd::bit::vector_or(self.data, rhs.data); + BVectOwned::new(self.dims, data) + } + + fn operator_xor(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let data = simd::bit::vector_xor(self.data, rhs.data); + BVectOwned::new(self.dims, data) + } + + #[inline(always)] + fn subvector(&self, bounds: impl RangeBounds) -> Option { + let start = match bounds.start_bound().cloned() { + Bound::Included(x) => x, + Bound::Excluded(u32::MAX) => return None, + Bound::Excluded(x) => x + 1, + Bound::Unbounded => 0, + }; + let end = match bounds.end_bound().cloned() { + Bound::Included(u32::MAX) => return None, + Bound::Included(x) => x + 1, + Bound::Excluded(x) => x, + Bound::Unbounded => self.dims, + }; + if start >= end || end > self.dims { + return None; + } + let dims = end - start; + let mut data = vec![0_u64; dims.div_ceil(BVECTOR_WIDTH) as _]; + { + let mut i = 0; + let mut j = start; + while j < end { + if self.data[(j / BVECTOR_WIDTH) as usize] & (1 << (j % BVECTOR_WIDTH)) != 0 { + data[(i / BVECTOR_WIDTH) as usize] |= 1 << (i % BVECTOR_WIDTH); + } + i += 1; + j += 1; + } + } + Self::Owned::new_checked(dims, data) + } +} + +impl PartialEq for BVectBorrowed<'_> { + fn eq(&self, other: &Self) -> bool { + if self.dims != other.dims { + return false; + } + for (&l, &r) in self.data.iter().zip(other.data.iter()) { + let l = l.reverse_bits(); + let r = r.reverse_bits(); + if l != r { + return false; + } + } + true + } +} + +impl PartialOrd for BVectBorrowed<'_> { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + if self.dims != other.dims { + return None; + } + for (&l, &r) in self.data.iter().zip(other.data.iter()) { + let l = l.reverse_bits(); + let r = r.reverse_bits(); + match l.cmp(&r) { + Ordering::Equal => (), + x => return Some(x), + } + } + Some(Ordering::Equal) + } +} diff --git a/crates/vector/src/lib.rs b/crates/vector/src/lib.rs new file mode 100644 index 0000000..e82e4a6 --- /dev/null +++ b/crates/vector/src/lib.rs @@ -0,0 +1,48 @@ +pub mod bvect; +pub mod scalar8; +pub mod svect; +pub mod vect; + +pub trait VectorOwned: Clone + serde::Serialize + for<'a> serde::Deserialize<'a> + 'static { + type Borrowed<'a>: VectorBorrowed; + + fn as_borrowed(&self) -> Self::Borrowed<'_>; + + fn zero(dims: u32) -> Self; +} + +pub trait VectorBorrowed: Copy { + type Owned: VectorOwned; + + fn own(&self) -> Self::Owned; + + fn dims(&self) -> u32; + + fn norm(&self) -> f32; + + fn operator_dot(self, rhs: Self) -> distance::Distance; + + fn operator_l2(self, rhs: Self) -> distance::Distance; + + fn operator_cos(self, rhs: Self) -> distance::Distance; + + fn operator_hamming(self, rhs: Self) -> distance::Distance; + + fn operator_jaccard(self, rhs: Self) -> distance::Distance; + + fn function_normalize(&self) -> Self::Owned; + + fn operator_add(&self, rhs: Self) -> Self::Owned; + + fn operator_sub(&self, rhs: Self) -> Self::Owned; + + fn operator_mul(&self, rhs: Self) -> Self::Owned; + + fn operator_and(&self, rhs: Self) -> Self::Owned; + + fn operator_or(&self, rhs: Self) -> Self::Owned; + + fn operator_xor(&self, rhs: Self) -> Self::Owned; + + fn subvector(&self, bounds: impl std::ops::RangeBounds) -> Option; +} diff --git a/src/types/scalar8.rs b/crates/vector/src/scalar8.rs similarity index 93% rename from src/types/scalar8.rs rename to crates/vector/src/scalar8.rs index 55c0074..ff9095a 100644 --- a/src/types/scalar8.rs +++ b/crates/vector/src/scalar8.rs @@ -1,5 +1,5 @@ -use base::distance::Distance; -use base::vector::{VectorBorrowed, VectorOwned}; 
+use crate::{VectorBorrowed, VectorOwned}; +use distance::Distance; use serde::{Deserialize, Serialize}; use std::ops::RangeBounds; @@ -13,7 +13,6 @@ pub struct Scalar8Owned { } impl Scalar8Owned { - #[allow(dead_code)] #[inline(always)] pub fn new(sum_of_x2: f32, k: f32, b: f32, sum_of_code: f32, code: Vec) -> Self { Self::new_checked(sum_of_x2, k, b, sum_of_code, code).expect("invalid data") @@ -182,7 +181,7 @@ impl VectorBorrowed for Scalar8Borrowed<'_> { #[inline(always)] fn operator_dot(self, rhs: Self) -> Distance { assert_eq!(self.code.len(), rhs.code.len()); - let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + let xy = self.k * rhs.k * simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + self.b * rhs.b * self.code.len() as f32 + self.k * rhs.b * self.sum_of_code + self.b * rhs.k * rhs.sum_of_code; @@ -192,7 +191,7 @@ impl VectorBorrowed for Scalar8Borrowed<'_> { #[inline(always)] fn operator_l2(self, rhs: Self) -> Distance { assert_eq!(self.code.len(), rhs.code.len()); - let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + let xy = self.k * rhs.k * simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + self.b * rhs.b * self.code.len() as f32 + self.k * rhs.b * self.sum_of_code + self.b * rhs.k * rhs.sum_of_code; @@ -204,7 +203,7 @@ impl VectorBorrowed for Scalar8Borrowed<'_> { #[inline(always)] fn operator_cos(self, rhs: Self) -> Distance { assert_eq!(self.code.len(), rhs.code.len()); - let xy = self.k * rhs.k * base::simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + let xy = self.k * rhs.k * simd::u8::reduce_sum_of_xy(self.code, rhs.code) as f32 + self.b * rhs.b * self.code.len() as f32 + self.k * rhs.b * self.sum_of_code + self.b * rhs.k * rhs.sum_of_code; @@ -279,7 +278,7 @@ impl VectorBorrowed for Scalar8Borrowed<'_> { }, self.k, self.b, - base::simd::u8::reduce_sum_of_x_as_u32(code) as f32, + simd::u8::reduce_sum_of_x(code) as f32, code.to_owned(), ) } diff --git a/crates/vector/src/svect.rs b/crates/vector/src/svect.rs new file mode 100644 index 0000000..08f678d --- /dev/null +++ b/crates/vector/src/svect.rs @@ -0,0 +1,445 @@ +use crate::{VectorBorrowed, VectorOwned}; +use distance::Distance; +use serde::{Deserialize, Serialize}; +use simd::Floating; +use std::ops::{Bound, RangeBounds}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SVectOwned { + dims: u32, + indexes: Vec, + values: Vec, +} + +impl SVectOwned { + #[inline(always)] + pub fn new(dims: u32, indexes: Vec, values: Vec) -> Self { + Self::new_checked(dims, indexes, values).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(dims: u32, indexes: Vec, values: Vec) -> Option { + if !(1..=1_048_575).contains(&dims) { + return None; + } + if indexes.len() != values.len() { + return None; + } + let len = indexes.len(); + for i in 1..len { + if !(indexes[i - 1] < indexes[i]) { + return None; + } + } + if len != 0 && !(indexes[len - 1] < dims) { + return None; + } + if S::reduce_or_of_is_zero_x(&values) { + return None; + } + unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + } + + /// # Safety + /// + /// * `dims` must be in `1..=1_048_575`. + /// * `indexes.len()` must be equal to `values.len()`. + /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. + /// * A floating number in `values` must not be positive zero or negative zero. 
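+    // Illustrative example (not in the original patch): a valid 5-dimensional
+    // sparse vector storing 2.0 at index 1 and -1.0 at index 3. The indexes
+    // are strictly increasing, all below `dims`, and no explicit zeros are
+    // stored, so every invariant above holds.
+    //
+    //     let v = SVectOwned::<f32>::new(5, vec![1, 3], vec![2.0, -1.0]);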
+ #[inline(always)] + pub unsafe fn new_unchecked(dims: u32, indexes: Vec, values: Vec) -> Self { + Self { + dims, + indexes, + values, + } + } + + #[inline(always)] + pub fn indexes(&self) -> &[u32] { + &self.indexes + } + + #[inline(always)] + pub fn values(&self) -> &[S] { + &self.values + } +} + +impl VectorOwned for SVectOwned { + type Borrowed<'a> = SVectBorrowed<'a, S>; + + #[inline(always)] + fn as_borrowed(&self) -> SVectBorrowed<'_, S> { + SVectBorrowed { + dims: self.dims, + indexes: &self.indexes, + values: &self.values, + } + } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(dims, vec![], vec![]) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SVectBorrowed<'a, S> { + dims: u32, + indexes: &'a [u32], + values: &'a [S], +} + +impl<'a, S: Floating> SVectBorrowed<'a, S> { + #[inline(always)] + pub fn new(dims: u32, indexes: &'a [u32], values: &'a [S]) -> Self { + Self::new_checked(dims, indexes, values).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(dims: u32, indexes: &'a [u32], values: &'a [S]) -> Option { + if !(1..=1_048_575).contains(&dims) { + return None; + } + if indexes.len() != values.len() { + return None; + } + let len = indexes.len(); + for i in 1..len { + if !(indexes[i - 1] < indexes[i]) { + return None; + } + } + if len != 0 && !(indexes[len - 1] < dims) { + return None; + } + for i in 0..len { + if values[i] == S::zero() { + return None; + } + } + unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + } + + /// # Safety + /// + /// * `dims` must be in `1..=1_048_575`. + /// * `indexes.len()` must be equal to `values.len()`. + /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. + /// * A floating number in `values` must not be positive zero or negative zero. 
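+    // A quick sanity check (illustrative, not in the original patch): the
+    // checked constructor rejects indexes that are not strictly increasing.
+    //
+    //     assert!(SVectBorrowed::<f32>::new_checked(5, &[3, 1], &[1.0, 2.0]).is_none());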
+ #[inline(always)] + pub unsafe fn new_unchecked(dims: u32, indexes: &'a [u32], values: &'a [S]) -> Self { + Self { + dims, + indexes, + values, + } + } + + #[inline(always)] + pub fn indexes(&self) -> &'a [u32] { + self.indexes + } + + #[inline(always)] + pub fn values(&self) -> &'a [S] { + self.values + } + + #[inline(always)] + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { + self.indexes.len() as u32 + } +} + +impl VectorBorrowed for SVectBorrowed<'_, S> { + type Owned = SVectOwned; + + #[inline(always)] + fn dims(&self) -> u32 { + self.dims + } + + #[inline(always)] + fn own(&self) -> SVectOwned { + SVectOwned { + dims: self.dims, + indexes: self.indexes.to_vec(), + values: self.values.to_vec(), + } + } + + #[inline(always)] + fn norm(&self) -> f32 { + S::reduce_sum_of_x2(self.values).sqrt() + } + + #[inline(always)] + fn operator_dot(self, rhs: Self) -> Distance { + let xy = S::reduce_sum_of_xy_sparse(self.indexes, self.values, rhs.indexes, rhs.values); + Distance::from(-xy) + } + + #[inline(always)] + fn operator_l2(self, rhs: Self) -> Distance { + let d2 = S::reduce_sum_of_d2_sparse(self.indexes, self.values, rhs.indexes, rhs.values); + Distance::from(d2) + } + + #[inline(always)] + fn operator_cos(self, rhs: Self) -> Distance { + let xy = S::reduce_sum_of_xy_sparse(self.indexes, self.values, rhs.indexes, rhs.values); + let x2 = S::reduce_sum_of_x2(self.values); + let y2 = S::reduce_sum_of_x2(rhs.values); + Distance::from(1.0 - xy / (x2 * y2).sqrt()) + } + + #[inline(always)] + fn operator_hamming(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn operator_jaccard(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn function_normalize(&self) -> SVectOwned { + let l = S::reduce_sum_of_x2(self.values).sqrt(); + let mut indexes = self.indexes.to_vec(); + let mut values = self.values.to_vec(); + let n = indexes.len(); + S::vector_mul_scalar_inplace(&mut values, 1.0 / l); + let mut j = 0_usize; + for i in 0..n { + if values[i] != S::zero() { + indexes[j] = indexes[i]; + values[j] = values[i]; + j += 1; + } + } + indexes.truncate(j); + values.truncate(j); + SVectOwned::new(self.dims, indexes, values) + } + + fn operator_add(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let size1 = self.len(); + let size2 = rhs.len(); + let mut pos1 = 0; + let mut pos2 = 0; + let mut pos = 0; + let mut indexes = vec![0; (size1 + size2) as _]; + let mut values = vec![S::zero(); (size1 + size2) as _]; + while pos1 < size1 && pos2 < size2 { + let lhs_index = self.indexes[pos1 as usize]; + let rhs_index = rhs.indexes[pos2 as usize]; + let lhs_value = self.values[pos1 as usize]; + let rhs_value = rhs.values[pos2 as usize]; + indexes[pos] = lhs_index.min(rhs_index); + values[pos] = S::scalar_add( + lhs_value.mask(lhs_index <= rhs_index), + rhs_value.mask(lhs_index >= rhs_index), + ); + pos1 += (lhs_index <= rhs_index) as u32; + pos2 += (lhs_index >= rhs_index) as u32; + pos += (values[pos] != S::zero()) as usize; + } + for i in pos1..size1 { + indexes[pos] = self.indexes[i as usize]; + values[pos] = self.values[i as usize]; + pos += 1; + } + for i in pos2..size2 { + indexes[pos] = rhs.indexes[i as usize]; + values[pos] = rhs.values[i as usize]; + pos += 1; + } + indexes.truncate(pos); + values.truncate(pos); + SVectOwned::new(self.dims, indexes, values) + } + + fn operator_sub(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let size1 = self.len(); + let size2 = rhs.len(); + let mut pos1 
= 0; + let mut pos2 = 0; + let mut pos = 0; + let mut indexes = vec![0; (size1 + size2) as _]; + let mut values = vec![S::zero(); (size1 + size2) as _]; + while pos1 < size1 && pos2 < size2 { + let lhs_index = self.indexes[pos1 as usize]; + let rhs_index = rhs.indexes[pos2 as usize]; + let lhs_value = self.values[pos1 as usize]; + let rhs_value = rhs.values[pos2 as usize]; + indexes[pos] = lhs_index.min(rhs_index); + values[pos] = S::scalar_sub( + lhs_value.mask(lhs_index <= rhs_index), + rhs_value.mask(lhs_index >= rhs_index), + ); + pos1 += (lhs_index <= rhs_index) as u32; + pos2 += (lhs_index >= rhs_index) as u32; + pos += (values[pos] != S::zero()) as usize; + } + for i in pos1..size1 { + indexes[pos] = self.indexes[i as usize]; + values[pos] = self.values[i as usize]; + pos += 1; + } + for i in pos2..size2 { + indexes[pos] = rhs.indexes[i as usize]; + values[pos] = S::scalar_neg(rhs.values[i as usize]); + pos += 1; + } + indexes.truncate(pos); + values.truncate(pos); + SVectOwned::new(self.dims, indexes, values) + } + + fn operator_mul(&self, rhs: Self) -> Self::Owned { + assert_eq!(self.dims, rhs.dims); + let size1 = self.len(); + let size2 = rhs.len(); + let mut pos1 = 0; + let mut pos2 = 0; + let mut pos = 0; + let mut indexes = vec![0; std::cmp::min(size1, size2) as _]; + let mut values = vec![S::zero(); std::cmp::min(size1, size2) as _]; + while pos1 < size1 && pos2 < size2 { + let lhs_index = self.indexes[pos1 as usize]; + let rhs_index = rhs.indexes[pos2 as usize]; + match lhs_index.cmp(&rhs_index) { + std::cmp::Ordering::Less => { + pos1 += 1; + } + std::cmp::Ordering::Equal => { + // only both indexes are not zero, values are multiplied + let lhs_value = self.values[pos1 as usize]; + let rhs_value = rhs.values[pos2 as usize]; + indexes[pos] = lhs_index; + values[pos] = S::scalar_mul(lhs_value, rhs_value); + pos1 += 1; + pos2 += 1; + // only increment pos if the value is not zero + pos += (values[pos] != S::zero()) as usize; + } + std::cmp::Ordering::Greater => { + pos2 += 1; + } + } + } + indexes.truncate(pos); + values.truncate(pos); + SVectOwned::new(self.dims, indexes, values) + } + + fn operator_and(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_or(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_xor(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + #[inline(always)] + fn subvector(&self, bounds: impl RangeBounds) -> Option { + let start = match bounds.start_bound().cloned() { + Bound::Included(x) => x, + Bound::Excluded(u32::MAX) => return None, + Bound::Excluded(x) => x + 1, + Bound::Unbounded => 0, + }; + let end = match bounds.end_bound().cloned() { + Bound::Included(u32::MAX) => return None, + Bound::Included(x) => x + 1, + Bound::Excluded(x) => x, + Bound::Unbounded => self.dims, + }; + if start >= end || end > self.dims { + return None; + } + let dims = end - start; + let s = self.indexes.partition_point(|&x| x < start); + let e = self.indexes.partition_point(|&x| x < end); + let indexes = self.indexes[s..e] + .iter() + .map(|x| x - start) + .collect::>(); + let values = self.values[s..e].to_vec(); + Self::Owned::new_checked(dims, indexes, values) + } +} + +impl PartialEq for SVectBorrowed<'_, S> { + fn eq(&self, other: &Self) -> bool { + if self.dims != other.dims { + return false; + } + if self.indexes.len() != other.indexes.len() { + return false; + } + for (&l, &r) in self.indexes.iter().zip(other.indexes.iter()) { + if l != r { + return false; + } + } + for (&l, &r) in 
self.values.iter().zip(other.values.iter()) { + if l != r { + return false; + } + } + true + } +} + +impl PartialOrd for SVectBorrowed<'_, S> { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + if self.dims != other.dims { + return None; + } + let mut lhs = self + .indexes + .iter() + .copied() + .zip(self.values.iter().copied()); + let mut rhs = other + .indexes + .iter() + .copied() + .zip(other.values.iter().copied()); + loop { + return match (lhs.next(), rhs.next()) { + (Some(lh), Some(rh)) => match lh.0.cmp(&rh.0) { + Ordering::Equal => match lh.1.partial_cmp(&rh.1)? { + Ordering::Equal => continue, + x => Some(x), + }, + Ordering::Less => Some(if lh.1 < S::zero() { + Ordering::Less + } else { + Ordering::Greater + }), + Ordering::Greater => Some(if S::zero() < rh.1 { + Ordering::Less + } else { + Ordering::Greater + }), + }, + (Some((_, x)), None) => Some(PartialOrd::partial_cmp(&x, &S::zero())?), + (None, Some((_, y))) => Some(PartialOrd::partial_cmp(&S::zero(), &y)?), + (None, None) => Some(Ordering::Equal), + }; + } + } +} diff --git a/crates/vector/src/vect.rs b/crates/vector/src/vect.rs new file mode 100644 index 0000000..34b186f --- /dev/null +++ b/crates/vector/src/vect.rs @@ -0,0 +1,216 @@ +use super::{VectorBorrowed, VectorOwned}; +use distance::Distance; +use serde::{Deserialize, Serialize}; +use simd::Floating; +use std::cmp::Ordering; +use std::ops::RangeBounds; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[repr(transparent)] +pub struct VectOwned(Vec); + +impl VectOwned { + #[inline(always)] + pub fn new(slice: Vec) -> Self { + Self::new_checked(slice).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(slice: Vec) -> Option { + if !(1..=65535).contains(&slice.len()) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + + /// # Safety + /// + /// * `slice.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked(slice: Vec) -> Self { + Self(slice) + } + + #[inline(always)] + pub fn slice(&self) -> &[S] { + self.0.as_slice() + } + + #[inline(always)] + pub fn slice_mut(&mut self) -> &mut [S] { + self.0.as_mut_slice() + } + + #[inline(always)] + pub fn into_vec(self) -> Vec { + self.0 + } +} + +impl VectorOwned for VectOwned { + type Borrowed<'a> = VectBorrowed<'a, S>; + + #[inline(always)] + fn as_borrowed(&self) -> VectBorrowed<'_, S> { + VectBorrowed(self.0.as_slice()) + } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(vec![S::zero(); dims as usize]) + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct VectBorrowed<'a, S>(&'a [S]); + +impl<'a, S: Floating> VectBorrowed<'a, S> { + #[inline(always)] + pub fn new(slice: &'a [S]) -> Self { + Self::new_checked(slice).expect("invalid data") + } + + #[inline(always)] + pub fn new_checked(slice: &'a [S]) -> Option { + if !(1..=65535).contains(&slice.len()) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + + /// # Safety + /// + /// * `slice.len()` must not be zero. 
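+    // Illustrative usage (not in the original patch): wrapping an existing,
+    // non-empty f32 slice without copying it. `VectorBorrowed` must be in
+    // scope for `dims()`.
+    //
+    //     let v = VectBorrowed::new(&[1.0f32, 2.0, 3.0]);
+    //     assert_eq!(v.dims(), 3);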
+ #[inline(always)] + pub unsafe fn new_unchecked(slice: &'a [S]) -> Self { + Self(slice) + } + + #[inline(always)] + pub fn slice(&self) -> &'a [S] { + self.0 + } +} + +impl VectorBorrowed for VectBorrowed<'_, S> { + type Owned = VectOwned; + + #[inline(always)] + fn dims(&self) -> u32 { + self.0.len() as u32 + } + + #[inline(always)] + fn own(&self) -> VectOwned { + VectOwned(self.0.to_vec()) + } + + #[inline(always)] + fn norm(&self) -> f32 { + S::reduce_sum_of_x2(self.0).sqrt() + } + + #[inline(always)] + fn operator_dot(self, rhs: Self) -> Distance { + Distance::from(-S::reduce_sum_of_xy(self.slice(), rhs.slice())) + } + + #[inline(always)] + fn operator_l2(self, rhs: Self) -> Distance { + Distance::from(S::reduce_sum_of_d2(self.slice(), rhs.slice())) + } + + #[inline(always)] + fn operator_cos(self, rhs: Self) -> Distance { + let xy = S::reduce_sum_of_xy(self.slice(), rhs.slice()); + let x2 = S::reduce_sum_of_x2(self.0); + let y2 = S::reduce_sum_of_x2(rhs.0); + Distance::from(1.0 - xy / (x2 * y2).sqrt()) + } + + #[inline(always)] + fn operator_hamming(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn operator_jaccard(self, _: Self) -> Distance { + unimplemented!() + } + + #[inline(always)] + fn function_normalize(&self) -> VectOwned { + let mut data = self.0.to_vec(); + let l = S::reduce_sum_of_x2(&data).sqrt(); + S::vector_mul_scalar_inplace(&mut data, 1.0 / l); + VectOwned(data) + } + + fn operator_add(&self, rhs: Self) -> Self::Owned { + VectOwned::new(S::vector_add(self.slice(), rhs.slice())) + } + + fn operator_sub(&self, rhs: Self) -> Self::Owned { + VectOwned::new(S::vector_sub(self.slice(), rhs.slice())) + } + + fn operator_mul(&self, rhs: Self) -> Self::Owned { + VectOwned::new(S::vector_mul(self.slice(), rhs.slice())) + } + + fn operator_and(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_or(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + fn operator_xor(&self, _: Self) -> Self::Owned { + unimplemented!() + } + + #[inline(always)] + fn subvector(&self, bounds: impl RangeBounds) -> Option { + let start_bound = bounds.start_bound().map(|x| *x as usize); + let end_bound = bounds.end_bound().map(|x| *x as usize); + let slice = self.0.get((start_bound, end_bound))?; + if slice.is_empty() { + return None; + } + Self::Owned::new_checked(slice.to_vec()) + } +} + +impl PartialEq for VectBorrowed<'_, S> { + fn eq(&self, other: &Self) -> bool { + if self.0.len() != other.0.len() { + return false; + } + let n = self.0.len(); + for i in 0..n { + if self.0[i] != other.0[i] { + return false; + } + } + true + } +} + +impl PartialOrd for VectBorrowed<'_, S> { + fn partial_cmp(&self, other: &Self) -> Option { + if self.0.len() != other.0.len() { + return None; + } + let n = self.0.len(); + for i in 0..n { + match PartialOrd::partial_cmp(&self.0[i], &other.0[i])? 
{ + Ordering::Less => return Some(Ordering::Less), + Ordering::Equal => continue, + Ordering::Greater => return Some(Ordering::Greater), + } + } + Some(Ordering::Equal) + } +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..3501136 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +style_edition = "2024" diff --git a/scripts/README.md b/scripts/README.md index 938ae90..9ed7172 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -10,8 +10,8 @@ export PGRX_VERSION=$(awk -F'version = "=|"' '/^pgrx\s*=.*version/ {print $2}' C export RUST_TOOLCHAIN=$(awk -F'"' '/^\s*channel\s*=/ {print $2}' rust-toolchain.toml) export PGRX_IMAGE=ghcr.io/tensorchord/vectorchord-pgrx:$PGRX_VERSION-$RUST_TOOLCHAIN -docker run --rm -v .:/workspace $PGRX_IMAGE cargo build --lib --features pg16 --profile opt -docker run --rm -v .:/workspace $PGRX_IMAGE ./tools/schema.sh --features pg16 --profile opt +docker run --rm -v .:/workspace $PGRX_IMAGE cargo build --lib --features pg16 --release +docker run --rm -v .:/workspace $PGRX_IMAGE ./tools/schema.sh --features pg16 --release ``` - (option 2) With Local Development Environment @@ -20,8 +20,8 @@ docker run --rm -v .:/workspace $PGRX_IMAGE ./tools/schema.sh --features pg16 -- sudo apt install -y build-essential libreadline-dev zlib1g-dev flex bison libxml2-dev libxslt-dev libssl-dev libxml2-utils xsltproc ccache pkg-config clang cargo install --locked cargo-pgrx cargo pgrx init -cargo build --package vchord --lib --features pg16 --profile opt -./tools/schema.sh --features pg16 --profile opt +cargo build --package vchord --lib --features pg16 --release +./tools/schema.sh --features pg16 --release ``` - build the debian package @@ -31,7 +31,6 @@ export SEMVER="0.0.0" export VERSION="16" export ARCH="x86_64" export PLATFORM="amd64" -export PROFILE="opt" ./tools/package.sh ``` diff --git a/src/bin/pgrx_embed.rs b/src/bin/pgrx_embed.rs index 20a006e..5f5c4d8 100644 --- a/src/bin/pgrx_embed.rs +++ b/src/bin/pgrx_embed.rs @@ -1,2 +1 @@ -#![feature(strict_provenance_lints)] ::pgrx::pgrx_embed!(); diff --git a/src/datatype/binary_scalar8.rs b/src/datatype/binary_scalar8.rs index 4dfe844..efb8758 100644 --- a/src/datatype/binary_scalar8.rs +++ b/src/datatype/binary_scalar8.rs @@ -1,8 +1,8 @@ use super::memory_scalar8::{Scalar8Input, Scalar8Output}; -use crate::types::scalar8::Scalar8Borrowed; -use base::vector::VectorBorrowed; use pgrx::datum::Internal; use pgrx::pg_sys::Oid; +use vector::VectorBorrowed; +use vector::scalar8::Scalar8Borrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_scalar8_send(vector: Scalar8Input<'_>) -> Vec { diff --git a/src/datatype/functions_scalar8.rs b/src/datatype/functions_scalar8.rs index fc1c221..adcdb30 100644 --- a/src/datatype/functions_scalar8.rs +++ b/src/datatype/functions_scalar8.rs @@ -1,17 +1,17 @@ use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_scalar8::Scalar8Output; -use crate::types::scalar8::Scalar8Borrowed; -use base::simd::ScalarLike; use half::f16; +use simd::Floating; +use vector::scalar8::Scalar8Borrowed; #[pgrx::pg_extern(sql = "")] fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Output { let vector = vector.as_borrowed(); let sum_of_x2 = f32::reduce_sum_of_x2(vector.slice()); let (k, b, code) = - base::simd::quantize::quantize(f32::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); - let sum_of_code = 
base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32; + simd::quantize::quantize(f32::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); + let sum_of_code = simd::u8::reduce_sum_of_x(&code) as f32; Scalar8Output::new(Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code)) } @@ -20,7 +20,7 @@ fn _vchord_halfvec_quantize_to_scalar8(vector: PgvectorHalfvecInput) -> Scalar8O let vector = vector.as_borrowed(); let sum_of_x2 = f16::reduce_sum_of_x2(vector.slice()); let (k, b, code) = - base::simd::quantize::quantize(f16::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); - let sum_of_code = base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32; + simd::quantize::quantize(f16::vector_to_f32_borrowed(vector.slice()).as_ref(), 255.0); + let sum_of_code = simd::u8::reduce_sum_of_x(&code) as f32; Scalar8Output::new(Scalar8Borrowed::new(sum_of_x2, k, b, sum_of_code, &code)) } diff --git a/src/datatype/memory_pgvector_halfvec.rs b/src/datatype/memory_pgvector_halfvec.rs index 7265a1b..1c065d9 100644 --- a/src/datatype/memory_pgvector_halfvec.rs +++ b/src/datatype/memory_pgvector_halfvec.rs @@ -1,4 +1,3 @@ -use base::vector::*; use half::f16; use pgrx::datum::FromDatum; use pgrx::datum::IntoDatum; @@ -11,6 +10,8 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; use std::ops::Deref; use std::ptr::NonNull; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; #[repr(C, align(8))] pub struct PgvectorHalfvecHeader { diff --git a/src/datatype/memory_pgvector_vector.rs b/src/datatype/memory_pgvector_vector.rs index d81492c..e3ab9f9 100644 --- a/src/datatype/memory_pgvector_vector.rs +++ b/src/datatype/memory_pgvector_vector.rs @@ -1,4 +1,3 @@ -use base::vector::*; use pgrx::datum::FromDatum; use pgrx::datum::IntoDatum; use pgrx::pg_sys::Datum; @@ -10,6 +9,8 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; use std::ops::Deref; use std::ptr::NonNull; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; #[repr(C, align(8))] pub struct PgvectorVectorHeader { diff --git a/src/datatype/memory_scalar8.rs b/src/datatype/memory_scalar8.rs index 3a7dab4..1641c63 100644 --- a/src/datatype/memory_scalar8.rs +++ b/src/datatype/memory_scalar8.rs @@ -1,5 +1,3 @@ -use crate::types::scalar8::Scalar8Borrowed; -use base::vector::*; use pgrx::datum::FromDatum; use pgrx::datum::IntoDatum; use pgrx::pg_sys::Datum; @@ -11,6 +9,8 @@ use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; use std::ops::Deref; use std::ptr::NonNull; +use vector::VectorBorrowed; +use vector::scalar8::Scalar8Borrowed; #[repr(C, align(8))] pub struct Scalar8Header { diff --git a/src/datatype/operators_pgvector_halfvec.rs b/src/datatype/operators_pgvector_halfvec.rs index 2e707eb..fb0492a 100644 --- a/src/datatype/operators_pgvector_halfvec.rs +++ b/src/datatype/operators_pgvector_halfvec.rs @@ -1,6 +1,7 @@ -use crate::datatype::memory_pgvector_halfvec::*; -use base::vector::{VectBorrowed, VectorBorrowed}; +use crate::datatype::memory_pgvector_halfvec::{PgvectorHalfvecInput, PgvectorHalfvecOutput}; use std::num::NonZero; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_halfvec_sphere_l2_in( diff --git a/src/datatype/operators_pgvector_vector.rs b/src/datatype/operators_pgvector_vector.rs index 2308ab9..e6b4be1 100644 --- 
a/src/datatype/operators_pgvector_vector.rs +++ b/src/datatype/operators_pgvector_vector.rs @@ -1,6 +1,7 @@ -use crate::datatype::memory_pgvector_vector::*; -use base::vector::{VectBorrowed, VectorBorrowed}; +use crate::datatype::memory_pgvector_vector::{PgvectorVectorInput, PgvectorVectorOutput}; use std::num::NonZero; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_vector_sphere_l2_in( diff --git a/src/datatype/operators_scalar8.rs b/src/datatype/operators_scalar8.rs index db6a372..8b75e59 100644 --- a/src/datatype/operators_scalar8.rs +++ b/src/datatype/operators_scalar8.rs @@ -1,7 +1,7 @@ use crate::datatype::memory_scalar8::{Scalar8Input, Scalar8Output}; -use crate::types::scalar8::Scalar8Borrowed; -use base::vector::*; use std::num::NonZero; +use vector::VectorBorrowed; +use vector::scalar8::Scalar8Borrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_scalar8_operator_ip(lhs: Scalar8Input<'_>, rhs: Scalar8Input<'_>) -> f32 { diff --git a/src/datatype/text_scalar8.rs b/src/datatype/text_scalar8.rs index 5de82da..8669a01 100644 --- a/src/datatype/text_scalar8.rs +++ b/src/datatype/text_scalar8.rs @@ -1,8 +1,8 @@ use super::memory_scalar8::Scalar8Output; use crate::datatype::memory_scalar8::Scalar8Input; -use crate::types::scalar8::Scalar8Borrowed; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; +use vector::scalar8::Scalar8Borrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_scalar8_in(input: &CStr, oid: Oid, typmod: i32) -> Scalar8Output { diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs index fe90a6d..08ac375 100644 --- a/src/datatype/typmod.rs +++ b/src/datatype/typmod.rs @@ -19,7 +19,6 @@ impl Typmod { None } } - #[allow(dead_code)] pub fn into_option_string(self) -> Option { use Typmod::*; match self { @@ -27,7 +26,6 @@ impl Typmod { Dims(x) => Some(x.get().to_string()), } } - #[allow(dead_code)] pub fn into_i32(self) -> i32 { use Typmod::*; match self { diff --git a/src/lib.rs b/src/lib.rs index 26d80ed..018388d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,11 @@ #![allow(clippy::collapsible_else_if)] -#![allow(clippy::identity_op)] -#![allow(clippy::needless_range_loop)] +#![allow(clippy::infallible_destructuring_match)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] -#![allow(clippy::int_plus_one)] -#![allow(clippy::unused_unit)] -#![allow(clippy::infallible_destructuring_match)] -#![feature(strict_provenance_lints)] mod datatype; mod postgres; mod projection; -mod types; mod upgrade; mod utils; mod vchordrq; diff --git a/src/postgres.rs b/src/postgres.rs index c24d669..6ff91e2 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -1,4 +1,5 @@ -use std::mem::{offset_of, MaybeUninit}; +use algorithm::{Opaque, Page, PageGuard, RelationRead, RelationWrite}; +use std::mem::{MaybeUninit, offset_of}; use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; @@ -7,7 +8,7 @@ const _: () = assert!( ); const fn size_of_contents() -> usize { - use pgrx::pg_sys::{PageHeaderData, BLCKSZ}; + use pgrx::pg_sys::{BLCKSZ, PageHeaderData}; let size_of_page = BLCKSZ as usize; let size_of_header = offset_of!(PageHeaderData, pd_linp); let size_of_opaque = size_of::(); @@ -15,17 +16,17 @@ const fn size_of_contents() -> usize { } #[repr(C, align(8))] -pub struct Page { +pub struct PostgresPage { header: pgrx::pg_sys::PageHeaderData, content: [u8; size_of_contents()], opaque: Opaque, } -const _: () = assert!(align_of::() == 
pgrx::pg_sys::MAXIMUM_ALIGNOF as usize); -const _: () = assert!(size_of::() == pgrx::pg_sys::BLCKSZ as usize); +const _: () = assert!(align_of::() == pgrx::pg_sys::MAXIMUM_ALIGNOF as usize); +const _: () = assert!(size_of::() == pgrx::pg_sys::BLCKSZ as usize); -impl Page { - pub fn init_mut(this: &mut MaybeUninit) -> &mut Self { +impl PostgresPage { + fn init_mut(this: &mut MaybeUninit) -> &mut Self { unsafe { pgrx::pg_sys::PageInit( this.as_mut_ptr() as pgrx::pg_sys::Page, @@ -42,26 +43,29 @@ impl Page { this } #[allow(dead_code)] - pub unsafe fn assume_init_mut(this: &mut MaybeUninit) -> &mut Self { + unsafe fn assume_init_mut(this: &mut MaybeUninit) -> &mut Self { let this = unsafe { MaybeUninit::assume_init_mut(this) }; assert_eq!(offset_of!(Self, opaque), this.header.pd_special as usize); this } #[allow(dead_code)] - pub fn clone_into_boxed(&self) -> Box { + fn clone_into_boxed(&self) -> Box { let mut result = Box::new_uninit(); unsafe { std::ptr::copy(self as *const Self, result.as_mut_ptr(), 1); result.assume_init() } } - pub fn get_opaque(&self) -> &Opaque { +} + +impl Page for PostgresPage { + fn get_opaque(&self) -> &Opaque { &self.opaque } - pub fn get_opaque_mut(&mut self) -> &mut Opaque { + fn get_opaque_mut(&mut self) -> &mut Opaque { &mut self.opaque } - pub fn len(&self) -> u16 { + fn len(&self) -> u16 { use pgrx::pg_sys::{ItemIdData, PageHeaderData}; assert!(self.header.pd_lower as usize <= size_of::()); assert!(self.header.pd_upper as usize <= size_of::()); @@ -70,7 +74,7 @@ impl Page { assert!(lower <= upper); ((lower - offset_of!(PageHeaderData, pd_linp)) / size_of::()) as u16 } - pub fn get(&self, i: u16) -> Option<&[u8]> { + fn get(&self, i: u16) -> Option<&[u8]> { use pgrx::pg_sys::{ItemIdData, PageHeaderData}; if i == 0 { return None; @@ -99,8 +103,7 @@ impl Page { Some(std::slice::from_raw_parts(ptr, lp_len as _)) } } - #[allow(unused)] - pub fn get_mut(&mut self, i: u16) -> Option<&mut [u8]> { + fn get_mut(&mut self, i: u16) -> Option<&mut [u8]> { use pgrx::pg_sys::{ItemIdData, PageHeaderData}; if i == 0 { return None; @@ -122,7 +125,7 @@ impl Page { Some(std::slice::from_raw_parts_mut(ptr, lp_len as _)) } } - pub fn alloc(&mut self, data: &[u8]) -> Option { + fn alloc(&mut self, data: &[u8]) -> Option { unsafe { let i = pgrx::pg_sys::PageAddItemExtended( (self as *const Self).cast_mut().cast(), @@ -131,19 +134,15 @@ impl Page { 0, 0, ); - if i == 0 { - None - } else { - Some(i) - } + if i == 0 { None } else { Some(i) } } } - pub fn free(&mut self, i: u16) { + fn free(&mut self, i: u16) { unsafe { pgrx::pg_sys::PageIndexTupleDeleteNoCompact((self as *mut Self).cast(), i); } } - pub fn reconstruct(&mut self, removes: &[u16]) { + fn reconstruct(&mut self, removes: &[u16]) { let mut removes = removes.to_vec(); removes.sort(); removes.dedup(); @@ -159,41 +158,34 @@ impl Page { } } } - pub fn freespace(&self) -> u16 { + fn freespace(&self) -> u16 { unsafe { pgrx::pg_sys::PageGetFreeSpace((self as *const Self).cast_mut().cast()) as u16 } } } -#[repr(C, align(8))] -pub struct Opaque { - pub next: u32, - pub skip: u32, -} - const _: () = assert!(align_of::() == pgrx::pg_sys::MAXIMUM_ALIGNOF as usize); -pub struct BufferReadGuard { +pub struct PostgresBufferReadGuard { buf: i32, - page: NonNull, + page: NonNull, id: u32, } -impl BufferReadGuard { - #[allow(dead_code)] - pub fn id(&self) -> u32 { +impl PageGuard for PostgresBufferReadGuard { + fn id(&self) -> u32 { self.id } } -impl Deref for BufferReadGuard { - type Target = Page; +impl Deref for PostgresBufferReadGuard { 
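+    // Summary of this refactor step (not a comment in the original patch):
+    // the buffer guards now implement the `algorithm` crate's `PageGuard`
+    // trait, and `Deref` (plus `DerefMut` on the write guard) exposes the
+    // underlying `PostgresPage` while the buffer lock is held, so index code
+    // can stay generic over `RelationRead`/`RelationWrite`.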
+ type Target = PostgresPage; - fn deref(&self) -> &Page { + fn deref(&self) -> &PostgresPage { unsafe { self.page.as_ref() } } } -impl Drop for BufferReadGuard { +impl Drop for PostgresBufferReadGuard { fn drop(&mut self) { unsafe { pgrx::pg_sys::UnlockReleaseBuffer(self.buf); @@ -201,36 +193,36 @@ impl Drop for BufferReadGuard { } } -pub struct BufferWriteGuard { +pub struct PostgresBufferWriteGuard { raw: pgrx::pg_sys::Relation, buf: i32, - page: NonNull, + page: NonNull, state: *mut pgrx::pg_sys::GenericXLogState, id: u32, tracking_freespace: bool, } -impl BufferWriteGuard { - pub fn id(&self) -> u32 { +impl PageGuard for PostgresBufferWriteGuard { + fn id(&self) -> u32 { self.id } } -impl Deref for BufferWriteGuard { - type Target = Page; +impl Deref for PostgresBufferWriteGuard { + type Target = PostgresPage; - fn deref(&self) -> &Page { + fn deref(&self) -> &PostgresPage { unsafe { self.page.as_ref() } } } -impl DerefMut for BufferWriteGuard { - fn deref_mut(&mut self) -> &mut Page { +impl DerefMut for PostgresBufferWriteGuard { + fn deref_mut(&mut self) -> &mut PostgresPage { unsafe { self.page.as_mut() } } } -impl Drop for BufferWriteGuard { +impl Drop for PostgresBufferWriteGuard { fn drop(&mut self) { unsafe { if std::thread::panicking() { @@ -248,19 +240,36 @@ impl Drop for BufferWriteGuard { } #[derive(Debug, Clone)] -pub struct Relation { +pub struct PostgresRelation { raw: pgrx::pg_sys::Relation, } -impl Relation { +impl PostgresRelation { pub unsafe fn new(raw: pgrx::pg_sys::Relation) -> Self { Self { raw } } - pub fn read(&self, id: u32) -> BufferReadGuard { + + #[allow(dead_code)] + pub fn len(&self) -> u32 { + unsafe { + pgrx::pg_sys::RelationGetNumberOfBlocksInFork( + self.raw, + pgrx::pg_sys::ForkNumber::MAIN_FORKNUM, + ) + } + } +} + +impl RelationRead for PostgresRelation { + type Page = PostgresPage; + + type ReadGuard<'a> = PostgresBufferReadGuard; + + fn read(&self, id: u32) -> Self::ReadGuard<'_> { assert!(id != u32::MAX, "no such page"); unsafe { use pgrx::pg_sys::{ - BufferGetPage, LockBuffer, ReadBufferExtended, ReadBufferMode, BUFFER_LOCK_SHARE, + BUFFER_LOCK_SHARE, BufferGetPage, LockBuffer, ReadBufferExtended, ReadBufferMode, }; let buf = ReadBufferExtended( self.raw, @@ -271,15 +280,21 @@ impl Relation { ); LockBuffer(buf, BUFFER_LOCK_SHARE as _); let page = NonNull::new(BufferGetPage(buf).cast()).expect("failed to get page"); - BufferReadGuard { buf, page, id } + PostgresBufferReadGuard { buf, page, id } } } - pub fn write(&self, id: u32, tracking_freespace: bool) -> BufferWriteGuard { +} + +impl RelationWrite for PostgresRelation { + type WriteGuard<'a> = PostgresBufferWriteGuard; + + fn write(&self, id: u32, tracking_freespace: bool) -> PostgresBufferWriteGuard { assert!(id != u32::MAX, "no such page"); unsafe { use pgrx::pg_sys::{ - ForkNumber, GenericXLogRegisterBuffer, GenericXLogStart, LockBuffer, - ReadBufferExtended, ReadBufferMode, BUFFER_LOCK_EXCLUSIVE, GENERIC_XLOG_FULL_IMAGE, + BUFFER_LOCK_EXCLUSIVE, ForkNumber, GENERIC_XLOG_FULL_IMAGE, + GenericXLogRegisterBuffer, GenericXLogStart, LockBuffer, ReadBufferExtended, + ReadBufferMode, }; let buf = ReadBufferExtended( self.raw, @@ -292,10 +307,10 @@ impl Relation { let state = GenericXLogStart(self.raw); let page = NonNull::new( GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE as _) - .cast::>(), + .cast::>(), ) .expect("failed to get page"); - BufferWriteGuard { + PostgresBufferWriteGuard { raw: self.raw, buf, page: page.cast(), @@ -305,12 +320,12 @@ impl Relation { } } } - pub fn 
extend(&self, tracking_freespace: bool) -> BufferWriteGuard { + fn extend(&self, tracking_freespace: bool) -> PostgresBufferWriteGuard { unsafe { use pgrx::pg_sys::{ - ExclusiveLock, ForkNumber, GenericXLogRegisterBuffer, GenericXLogStart, LockBuffer, - LockRelationForExtension, ReadBufferExtended, ReadBufferMode, - UnlockRelationForExtension, BUFFER_LOCK_EXCLUSIVE, GENERIC_XLOG_FULL_IMAGE, + BUFFER_LOCK_EXCLUSIVE, ExclusiveLock, ForkNumber, GENERIC_XLOG_FULL_IMAGE, + GenericXLogRegisterBuffer, GenericXLogStart, LockBuffer, LockRelationForExtension, + ReadBufferExtended, ReadBufferMode, UnlockRelationForExtension, }; LockRelationForExtension(self.raw, ExclusiveLock as _); let buf = ReadBufferExtended( @@ -325,11 +340,11 @@ impl Relation { let state = GenericXLogStart(self.raw); let mut page = NonNull::new( GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE as _) - .cast::>(), + .cast::>(), ) .expect("failed to get page"); - Page::init_mut(page.as_mut()); - BufferWriteGuard { + PostgresPage::init_mut(page.as_mut()); + PostgresBufferWriteGuard { raw: self.raw, buf, page: page.cast(), @@ -339,7 +354,7 @@ impl Relation { } } } - pub fn search(&self, freespace: usize) -> Option { + fn search(&self, freespace: usize) -> Option { unsafe { loop { let id = pgrx::pg_sys::GetPageWithFreeSpace(self.raw, freespace); @@ -357,13 +372,4 @@ impl Relation { } } } - #[allow(dead_code)] - pub fn len(&self) -> u32 { - unsafe { - pgrx::pg_sys::RelationGetNumberOfBlocksInFork( - self.raw, - pgrx::pg_sys::ForkNumber::MAIN_FORKNUM, - ) - } - } } diff --git a/src/projection.rs b/src/projection.rs index 4273180..fbcaeff 100644 --- a/src/projection.rs +++ b/src/projection.rs @@ -1,75 +1,21 @@ -use nalgebra::DMatrix; +use random_orthogonal_matrix::random_orthogonal_matrix; use std::sync::OnceLock; -fn random_matrix(n: usize) -> DMatrix { - use rand::{Rng, SeedableRng}; - use rand_chacha::ChaCha12Rng; - use rand_distr::StandardNormal; - let mut rng = ChaCha12Rng::from_seed([7; 32]); - DMatrix::from_fn(n, n, |_, _| rng.sample(StandardNormal)) +fn matrix(n: usize) -> Option<&'static Vec>> { + static MATRIXS: [OnceLock>>; 1 + 60000] = [const { OnceLock::new() }; 1 + 60000]; + MATRIXS + .get(n) + .map(|x| x.get_or_init(|| random_orthogonal_matrix(n))) } -#[ignore] -#[test] -fn check_all_matrixs_are_full_rank() { - let parallelism = std::thread::available_parallelism().unwrap().get(); - std::thread::scope(|scope| { - let mut threads = vec![]; - for remainder in 0..parallelism { - threads.push(scope.spawn(move || { - for n in (0..=60000).filter(|x| x % parallelism == remainder) { - let matrix = random_matrix(n); - assert!(matrix.is_invertible()); - } - })); - } - for thread in threads { - thread.join().unwrap(); - } - }); -} - -#[test] -fn check_matrices() { - assert_eq!( - orthogonal_matrix(2), - vec![vec![-0.5424608, -0.8400813], vec![0.8400813, -0.54246056]] - ); - assert_eq!( - orthogonal_matrix(3), - vec![ - vec![-0.5309615, -0.69094884, -0.49058124], - vec![0.8222731, -0.56002235, -0.10120347], - vec![0.20481002, 0.45712686, -0.86549866] - ] - ); -} - -fn orthogonal_matrix(n: usize) -> Vec> { - use nalgebra::QR; - let matrix = random_matrix(n); - // QR decomposition is unique if the matrix is full rank - let qr = QR::new(matrix); - let q = qr.q(); - let mut projection = Vec::new(); - for row in q.row_iter() { - projection.push(row.iter().copied().collect::>()); - } - projection -} - -static MATRIXS: [OnceLock>>; 1 + 60000] = [const { OnceLock::new() }; 1 + 60000]; - pub fn prewarm(n: usize) { - if n <= 60000 
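// [Illustrative only, not part of this patch] The src/postgres.rs hunks above move buffer
// access behind the Page/PageGuard/RelationRead/RelationWrite traits from crates/algorithm,
// so search and build code can be written against traits instead of pgrx buffers. A toy
// in-memory stand-in for the same read-guard pattern (all names below are hypothetical):
use std::collections::HashMap;
use std::ops::Deref;

struct MemPage {
    items: Vec<Vec<u8>>,
}

impl MemPage {
    // item numbers are 1-based, mirroring PostgreSQL line pointers
    fn get(&self, i: u16) -> Option<&[u8]> {
        self.items.get(usize::from(i.checked_sub(1)?)).map(Vec::as_slice)
    }
}

struct MemReadGuard<'a> {
    id: u32,
    page: &'a MemPage,
}

impl Deref for MemReadGuard<'_> {
    type Target = MemPage;
    fn deref(&self) -> &MemPage {
        self.page
    }
}

struct MemRelation {
    pages: HashMap<u32, MemPage>,
}

impl MemRelation {
    // like PostgresRelation::read, returns a guard that derefs to the page
    fn read(&self, id: u32) -> MemReadGuard<'_> {
        MemReadGuard { id, page: &self.pages[&id] }
    }
}

fn main() {
    let relation = MemRelation {
        pages: HashMap::from([(0u32, MemPage { items: vec![b"meta".to_vec()] })]),
    };
    let guard = relation.read(0);
    assert_eq!(guard.id, 0);
    assert_eq!(guard.get(1), Some(b"meta".as_slice()));
}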
{ - MATRIXS[n].get_or_init(|| orthogonal_matrix(n)); - } + let _ = matrix(n); } pub fn project(vector: &[f32]) -> Vec { - use base::simd::ScalarLike; + use simd::Floating; let n = vector.len(); - let matrix = MATRIXS[n].get_or_init(|| orthogonal_matrix(n)); + let matrix = matrix(n).expect("dimension too large"); (0..n) .map(|i| f32::reduce_sum_of_xy(vector, &matrix[i])) .collect() diff --git a/src/types/mod.rs b/src/types/mod.rs deleted file mode 100644 index af08ee7..0000000 --- a/src/types/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod scalar8; diff --git a/src/utils/k_means.rs b/src/utils/k_means.rs index 8aeac72..7b44a24 100644 --- a/src/utils/k_means.rs +++ b/src/utils/k_means.rs @@ -1,16 +1,14 @@ -#![allow(clippy::ptr_arg)] - use super::parallelism::{ParallelIterator, Parallelism}; -use base::simd::*; use half::f16; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use simd::Floating; pub fn k_means( parallelism: &P, c: usize, dims: usize, - samples: &Vec>, + samples: &[Vec], is_spherical: bool, iterations: usize, ) -> Vec> { @@ -18,9 +16,9 @@ pub fn k_means( assert!(dims > 0); let n = samples.len(); if n <= c { - quick_centers(c, dims, samples.clone(), is_spherical) + quick_centers(c, dims, samples.to_vec(), is_spherical) } else { - let compute = |parallelism: &P, centroids: &Vec>| { + let compute = |parallelism: &P, centroids: &[Vec]| { if n >= 1000 && c >= 1000 { rabitq_index(parallelism, dims, n, c, samples, centroids) } else { @@ -62,9 +60,7 @@ fn quick_centers( let mut rng = rand::thread_rng(); let mut centroids = samples; for _ in n..c { - let r = (0..dims) - .map(|_| f32::from_f32(rng.gen_range(-1.0f32..1.0f32))) - .collect(); + let r = (0..dims).map(|_| rng.gen_range(-1.0f32..1.0f32)).collect(); centroids.push(r); } if is_spherical { @@ -82,152 +78,37 @@ fn rabitq_index( dims: usize, n: usize, c: usize, - samples: &Vec>, - centroids: &Vec>, + samples: &[Vec], + centroids: &[Vec], ) -> Vec { - fn code_alpha(vector: &[f32]) -> (f32, f32, f32, f32) { - let dims = vector.len(); - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x2.sqrt(); - let x0 = sum_of_abs_x / (sum_of_x2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.scalar_is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.scalar_is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - (sum_of_x2, factor_ppc, factor_ip, factor_err) - } - fn code_beta(vector: &[f32]) -> Vec { - let dims = vector.len(); - let mut code = Vec::new(); - for i in 0..dims { - code.push(vector[i].scalar_is_sign_positive() as u8); - } - code - } let mut a0 = Vec::new(); let mut a1 = Vec::new(); let mut a2 = Vec::new(); let mut a3 = Vec::new(); let mut a4 = Vec::new(); for vectors in centroids.chunks(32) { - use base::simd::fast_scan::b4::pack; - let code_alphas = std::array::from_fn::<_, 32, _>(|i| { + use simd::fast_scan::pack; + let x = std::array::from_fn::<_, 32, _>(|i| { if let Some(vector) = vectors.get(i) { - code_alpha(vector) + rabitq::block::code(dims as _, vector) } else { - (0.0, 0.0, 0.0, 0.0) - } - }); - let code_betas = std::array::from_fn::<_, 32, _>(|i| { - let mut result = vec![0_u8; dims.div_ceil(4)]; - if let Some(vector) = 
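// [Illustrative only] project() above multiplies the input by a cached orthogonal matrix,
// one reduce_sum_of_xy per output coordinate, with matrices produced by
// random_orthogonal_matrix(n) for dimensions up to 60000. A plain sketch of the same
// operation (the 2x2 rotation here is just an example, not the generated matrix):
fn project(matrix: &[Vec<f32>], vector: &[f32]) -> Vec<f32> {
    matrix
        .iter()
        .map(|row| row.iter().zip(vector).map(|(m, x)| m * x).sum())
        .collect()
}

fn norm(x: &[f32]) -> f32 {
    x.iter().map(|a| a * a).sum::<f32>().sqrt()
}

fn main() {
    // a 90-degree rotation is orthogonal, so Euclidean lengths are preserved
    let matrix = vec![vec![0.0, -1.0], vec![1.0, 0.0]];
    let v = [3.0f32, 4.0];
    let p = project(&matrix, &v);
    assert_eq!(p, vec![-4.0, 3.0]);
    assert!((norm(&p) - norm(&v)).abs() < 1e-6);
}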
vectors.get(i) { - let mut c = code_beta(vector); - c.resize(dims.next_multiple_of(4), 0); - for i in 0..dims.div_ceil(4) { - for j in 0..4 { - result[i] |= c[i * 4 + j] << j; - } - } + rabitq::block::dummy_code(dims as _) } - result }); - a0.push(code_alphas.map(|x| x.0)); - a1.push(code_alphas.map(|x| x.1)); - a2.push(code_alphas.map(|x| x.2)); - a3.push(code_alphas.map(|x| x.3)); - a4.push(pack(dims.div_ceil(4) as _, code_betas).collect::>()); + a0.push(x.each_ref().map(|x| x.dis_u_2)); + a1.push(x.each_ref().map(|x| x.factor_ppc)); + a2.push(x.each_ref().map(|x| x.factor_ip)); + a3.push(x.each_ref().map(|x| x.factor_err)); + a4.push(pack(dims.div_ceil(4) as _, x.map(|x| x.signs)).collect::>()); } parallelism - .into_par_iter(0..n) + .rayon_into_par_iter(0..n) .map(|i| { - fn generate(mut qvector: Vec) -> Vec { - let dims = qvector.len() as u32; - let t = dims.div_ceil(4); - qvector.resize(qvector.len().next_multiple_of(4), 0); - let mut lut = vec![0u8; t as usize * 16]; - for i in 0..t as usize { - unsafe { - // this hint is used to skip bound checks - std::hint::assert_unchecked(4 * i + 3 < qvector.len()); - std::hint::assert_unchecked(16 * i + 15 < lut.len()); - } - let t0 = qvector[4 * i + 0]; - let t1 = qvector[4 * i + 1]; - let t2 = qvector[4 * i + 2]; - let t3 = qvector[4 * i + 3]; - lut[16 * i + 0b0000] = 0; - lut[16 * i + 0b0001] = t0; - lut[16 * i + 0b0010] = t1; - lut[16 * i + 0b0011] = t1 + t0; - lut[16 * i + 0b0100] = t2; - lut[16 * i + 0b0101] = t2 + t0; - lut[16 * i + 0b0110] = t2 + t1; - lut[16 * i + 0b0111] = t2 + t1 + t0; - lut[16 * i + 0b1000] = t3; - lut[16 * i + 0b1001] = t3 + t0; - lut[16 * i + 0b1010] = t3 + t1; - lut[16 * i + 0b1011] = t3 + t1 + t0; - lut[16 * i + 0b1100] = t3 + t2; - lut[16 * i + 0b1101] = t3 + t2 + t0; - lut[16 * i + 0b1110] = t3 + t2 + t1; - lut[16 * i + 0b1111] = t3 + t2 + t1 + t0; - } - lut - } - fn fscan_process_lowerbound( - dims: u32, - lut: &(f32, f32, f32, f32, Vec), - (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( - &[f32; 32], - &[f32; 32], - &[f32; 32], - &[f32; 32], - &[u8], - ), - epsilon: f32, - ) -> [Distance; 32] { - use base::simd::fast_scan::b4::fast_scan_b4; - let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let r = fast_scan_b4(dims.div_ceil(4), t, s); - std::array::from_fn(|i| { - let rough = dis_u_2[i] - + dis_v_2 - + b * factor_ppc[i] - + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; - let err = factor_err[i] * dis_v_2.sqrt(); - Distance::from_f32(rough - epsilon * err) - }) - } - use base::distance::Distance; - use base::simd::quantize; - - let lut = { - let vector = &samples[i]; - let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = - quantize::quantize(f32::vector_to_f32_borrowed(vector).as_ref(), 15.0); - let qvector_sum = if vector.len() <= 4369 { - u8::reduce_sum_of_x_as_u16(&qvector) as f32 - } else { - u8::reduce_sum_of_x_as_u32(&qvector) as f32 - }; - (dis_v_2, b, k, qvector_sum, generate(qvector)) - }; - + use distance::Distance; + let lut = rabitq::block::fscan_preprocess(&samples[i]); let mut result = (Distance::INFINITY, 0); for block in 0..c.div_ceil(32) { - let lowerbound = fscan_process_lowerbound( + let lowerbound = rabitq::block::fscan_process_lowerbound_l2( dims as _, &lut, (&a0[block], &a1[block], &a2[block], &a3[block], &a4[block]), @@ -253,11 +134,11 @@ fn flat_index( _dims: usize, n: usize, c: usize, - samples: &Vec>, - centroids: &Vec>, + samples: &[Vec], + centroids: &[Vec], ) -> Vec { parallelism - .into_par_iter(0..n) + .rayon_into_par_iter(0..n) .map(|i| { let mut 
result = (f32::INFINITY, 0); for j in 0..c { @@ -279,18 +160,18 @@ struct LloydKMeans<'a, P, F> { centroids: Vec>, assign: Vec, rng: StdRng, - samples: &'a Vec>, + samples: &'a [Vec], compute: F, } const DELTA: f32 = f16::EPSILON.to_f32_const(); -impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a, P, F> { +impl<'a, P: Parallelism, F: Fn(&P, &[Vec]) -> Vec> LloydKMeans<'a, P, F> { fn new( parallelism: &'a P, c: usize, dims: usize, - samples: &'a Vec>, + samples: &'a [Vec], is_spherical: bool, compute: F, ) -> Self { @@ -304,7 +185,7 @@ impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a } let assign = parallelism - .into_par_iter(0..n) + .rayon_into_par_iter(0..n) .map(|i| { let mut result = (f32::INFINITY, 0); for j in 0..c { @@ -339,7 +220,7 @@ impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a let (sum, mut count) = self .parallelism - .into_par_iter(0..n) + .rayon_into_par_iter(0..n) .fold( || (vec![vec![f32::zero(); dims]; c], vec![0.0f32; c]), |(mut sum, mut count), i| { @@ -361,7 +242,7 @@ impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a let mut centroids = self .parallelism - .into_par_iter(0..c) + .rayon_into_par_iter(0..c) .map(|i| f32::vector_mul_scalar(&sum[i], 1.0 / count[i])) .collect::>(); @@ -379,15 +260,15 @@ impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a o = (o + 1) % c; } centroids[i] = centroids[o].clone(); - f32::kmeans_helper(&mut centroids[i], 1.0 + DELTA, 1.0 - DELTA); - f32::kmeans_helper(&mut centroids[o], 1.0 - DELTA, 1.0 + DELTA); + vector_mul_scalars_inplace(&mut centroids[i], [1.0 + DELTA, 1.0 - DELTA]); + vector_mul_scalars_inplace(&mut centroids[o], [1.0 - DELTA, 1.0 + DELTA]); count[i] = count[o] / 2.0; count[o] -= count[i]; } if self.is_spherical { self.parallelism - .into_par_iter(&mut centroids) + .rayon_into_par_iter(&mut centroids) .for_each(|centroid| { let l = f32::reduce_sum_of_x2(centroid).sqrt(); f32::vector_mul_scalar_inplace(centroid, 1.0 / l); @@ -408,3 +289,14 @@ impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a self.centroids } } + +fn vector_mul_scalars_inplace(this: &mut [f32], scalars: [f32; 2]) { + let n: usize = this.len(); + for i in 0..n { + if i % 2 == 0 { + this[i] *= scalars[0]; + } else { + this[i] *= scalars[1]; + } + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 2d9a3b7..1b07dc6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,2 @@ -pub mod infinite_byte_chunks; pub mod k_means; pub mod parallelism; diff --git a/src/utils/parallelism.rs b/src/utils/parallelism.rs index bd11191..b960b56 100644 --- a/src/utils/parallelism.rs +++ b/src/utils/parallelism.rs @@ -7,8 +7,7 @@ pub use rayon::iter::ParallelIterator; pub trait Parallelism: Send + Sync { fn check(&self); - #[allow(clippy::wrong_self_convention)] - fn into_par_iter(&self, x: I) -> I::Iter; + fn rayon_into_par_iter(&self, x: I) -> I::Iter; } struct ParallelismCheckPanic(Box); @@ -57,7 +56,7 @@ impl Parallelism for RayonParallelism { } } - fn into_par_iter(&self, x: I) -> I::Iter { + fn rayon_into_par_iter(&self, x: I) -> I::Iter { x.into_par_iter() } } diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 6554bfc..4213e59 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -1,26 +1,25 @@ -use super::RelationWrite; use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::algorithm::PageGuard; use crate::vchordrq::index::am_options::Opfamily; +use 
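// [Illustrative only] The vector_mul_scalars_inplace helper added above replaces the old
// f32::kmeans_helper call: when a cluster ends up empty, its centroid is copied from a
// populated one and the two copies are pushed apart by scaling even and odd coordinates
// with (1 + DELTA) and (1 - DELTA) in opposite orders. A standalone version of that nudge:
fn mul_scalars_inplace(this: &mut [f32], scalars: [f32; 2]) {
    for (i, x) in this.iter_mut().enumerate() {
        *x *= scalars[i % 2];
    }
}

fn main() {
    const DELTA: f32 = 0.0009765625; // f16::EPSILON written out as an f32 constant
    let mut kept = vec![1.0f32, 2.0, 3.0, 4.0];
    let mut split = kept.clone();
    mul_scalars_inplace(&mut kept, [1.0 - DELTA, 1.0 + DELTA]);
    mul_scalars_inplace(&mut split, [1.0 + DELTA, 1.0 - DELTA]);
    // the two centroids now differ in every coordinate, so later Lloyd iterations can separate them
    assert!(kept.iter().zip(&split).all(|(a, b)| a != b));
}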
crate::vchordrq::types::DistanceKind; use crate::vchordrq::types::VchordrqBuildOptions; use crate::vchordrq::types::VchordrqExternalBuildOptions; use crate::vchordrq::types::VchordrqIndexingOptions; use crate::vchordrq::types::VchordrqInternalBuildOptions; use crate::vchordrq::types::VectorOptions; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::simd::ScalarLike; -use base::vector::VectorBorrowed; +use algorithm::{Page, PageGuard, RelationWrite}; use rand::Rng; use rkyv::ser::serializers::AllocSerializer; +use simd::Floating; use std::marker::PhantomData; +use std::num::NonZeroU64; use std::sync::Arc; +use vector::VectorBorrowed; pub trait HeapRelation { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, V)); + F: FnMut((NonZeroU64, V)); fn opfamily(&self) -> Opfamily; } @@ -207,8 +206,8 @@ impl Structure { let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; - use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; + use vector::VectorBorrowed; let schema_query = "SELECT n.nspname::TEXT FROM pg_catalog.pg_extension e LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace diff --git a/src/vchordrq/algorithm/insert.rs b/src/vchordrq/algorithm/insert.rs index 323fa24..f625dca 100644 --- a/src/vchordrq/algorithm/insert.rs +++ b/src/vchordrq/algorithm/insert.rs @@ -1,19 +1,18 @@ -use super::RelationWrite; -use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; -use crate::vchordrq::algorithm::PageGuard; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::vector::VectorBorrowed; +use crate::vchordrq::types::DistanceKind; +use algorithm::{Page, PageGuard, RelationWrite}; +use always_equal::AlwaysEqual; +use distance::Distance; use std::cmp::Reverse; use std::collections::BinaryHeap; +use std::num::NonZeroU64; +use vector::VectorBorrowed; pub fn insert( relation: impl RelationWrite + Clone, - payload: Pointer, + payload: NonZeroU64, vector: V, distance_kind: DistanceKind, in_building: bool, @@ -41,7 +40,7 @@ pub fn insert( for i in (0..slices.len()).rev() { let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple:: { slice: slices[i].to_vec(), - payload: Some(payload.as_u64()), + payload: Some(payload), chain, }) .unwrap(); @@ -56,7 +55,7 @@ pub fn insert( } chain.ok().unwrap() }; - let h0_payload = payload.as_u64(); + let h0_payload = payload; let mut list = { let Some((_, original)) = vectors::vector_dist::( relation.clone(), @@ -90,7 +89,7 @@ pub fn insert( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = rabitq::process_lowerbound( distance_kind, dims, lut, diff --git a/src/vchordrq/algorithm/mod.rs b/src/vchordrq/algorithm/mod.rs index 41744d7..88239a8 100644 --- a/src/vchordrq/algorithm/mod.rs +++ b/src/vchordrq/algorithm/mod.rs @@ -6,62 +6,3 @@ pub mod scan; pub mod tuples; pub mod vacuum; pub mod vectors; - -use crate::postgres::Page; -use std::ops::{Deref, DerefMut}; - -pub trait PageGuard { - fn id(&self) -> u32; -} - -pub trait RelationRead { - type ReadGuard<'a>: PageGuard + Deref - where - Self: 'a; - fn read(&self, id: u32) -> Self::ReadGuard<'_>; -} - -pub trait RelationWrite: RelationRead { - type 
WriteGuard<'a>: PageGuard + DerefMut - where - Self: 'a; - fn write(&self, id: u32, tracking_freespace: bool) -> Self::WriteGuard<'_>; - fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_>; - fn search(&self, freespace: usize) -> Option>; -} - -impl PageGuard for crate::postgres::BufferReadGuard { - fn id(&self) -> u32 { - self.id() - } -} - -impl PageGuard for crate::postgres::BufferWriteGuard { - fn id(&self) -> u32 { - self.id() - } -} - -impl RelationRead for crate::postgres::Relation { - type ReadGuard<'a> = crate::postgres::BufferReadGuard; - - fn read(&self, id: u32) -> Self::ReadGuard<'_> { - self.read(id) - } -} - -impl RelationWrite for crate::postgres::Relation { - type WriteGuard<'a> = crate::postgres::BufferWriteGuard; - - fn write(&self, id: u32, tracking_freespace: bool) -> Self::WriteGuard<'_> { - self.write(id, tracking_freespace) - } - - fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_> { - self.extend(tracking_freespace) - } - - fn search(&self, freespace: usize) -> Option> { - self.search(freespace) - } -} diff --git a/src/vchordrq/algorithm/prewarm.rs b/src/vchordrq/algorithm/prewarm.rs index 6d01f7c..6a7dc25 100644 --- a/src/vchordrq/algorithm/prewarm.rs +++ b/src/vchordrq/algorithm/prewarm.rs @@ -1,6 +1,6 @@ -use super::RelationRead; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; +use algorithm::{Page, RelationRead}; use std::fmt::Write; pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> String { diff --git a/src/vchordrq/algorithm/rabitq.rs b/src/vchordrq/algorithm/rabitq.rs index b7b3858..4f406e1 100644 --- a/src/vchordrq/algorithm/rabitq.rs +++ b/src/vchordrq/algorithm/rabitq.rs @@ -1,128 +1,21 @@ -use base::distance::{Distance, DistanceKind}; -use base::simd::ScalarLike; +use crate::vchordrq::types::DistanceKind; +use distance::Distance; -#[derive(Debug, Clone)] -pub struct Code { - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub signs: Vec, -} - -impl Code { - pub fn t(&self) -> Vec { - use crate::utils::infinite_byte_chunks::InfiniteByteChunks; - let mut result = Vec::new(); - for x in InfiniteByteChunks::<_, 64>::new(self.signs.iter().copied()) - .take(self.signs.len().div_ceil(64)) - { - let mut r = 0_u64; - for i in 0..64 { - r |= (x[i] as u64) << i; - } - result.push(r); - } - result - } -} - -pub fn code(dims: u32, vector: &[f32]) -> Code { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x_2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x_2.sqrt(); - let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - let mut signs = Vec::new(); - for i in 0..dims { - signs.push(vector[i as usize].is_sign_positive() as u8); - } - Code { - dis_u_2: sum_of_x_2, - factor_ppc, - factor_ip, - factor_err, - signs, - } -} +pub use rabitq::binary::Code; +pub use rabitq::binary::Lut; +pub use rabitq::binary::code; +pub use rabitq::binary::preprocess; +pub use rabitq::binary::{process_lowerbound_dot, process_lowerbound_l2}; -pub type Lut = (f32, f32, f32, f32, 
(Vec, Vec, Vec, Vec)); - -pub fn fscan_preprocess(vector: &[f32]) -> Lut { - use base::simd::quantize; - let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize(vector, 15.0); - let qvector_sum = if vector.len() <= 4369 { - base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 - } else { - base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 - }; - (dis_v_2, b, k, qvector_sum, binarize(&qvector)) -} - -pub fn fscan_process_lowerbound( +pub fn process_lowerbound( distance_kind: DistanceKind, - _dims: u32, + dims: u32, lut: &Lut, - (dis_u_2, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), + code: (f32, f32, f32, f32, &[u64]), epsilon: f32, ) -> Distance { match distance_kind { - DistanceKind::L2 => { - let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let value = asymmetric_binary_dot_product(t, s) as u16; - let rough = dis_u_2 - + dis_v_2 - + b * factor_ppc - + ((2.0 * value as f32) - qvector_sum) * factor_ip * k; - let err = factor_err * dis_v_2.sqrt(); - Distance::from_f32(rough - epsilon * err) - } - DistanceKind::Dot => { - let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let value = asymmetric_binary_dot_product(t, s) as u16; - let rough = - 0.5 * b * factor_ppc + 0.5 * ((2.0 * value as f32) - qvector_sum) * factor_ip * k; - let err = 0.5 * factor_err * dis_v_2.sqrt(); - Distance::from_f32(rough - epsilon * err) - } - DistanceKind::Hamming => unimplemented!(), - DistanceKind::Jaccard => unimplemented!(), - } -} - -fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { - let n = vector.len(); - let mut t0 = vec![0u64; n.div_ceil(64)]; - let mut t1 = vec![0u64; n.div_ceil(64)]; - let mut t2 = vec![0u64; n.div_ceil(64)]; - let mut t3 = vec![0u64; n.div_ceil(64)]; - for i in 0..n { - t0[i / 64] |= (((vector[i] >> 0) & 1) as u64) << (i % 64); - t1[i / 64] |= (((vector[i] >> 1) & 1) as u64) << (i % 64); - t2[i / 64] |= (((vector[i] >> 2) & 1) as u64) << (i % 64); - t3[i / 64] |= (((vector[i] >> 3) & 1) as u64) << (i % 64); + DistanceKind::L2 => process_lowerbound_l2(dims, lut, code, epsilon), + DistanceKind::Dot => process_lowerbound_dot(dims, lut, code, epsilon), } - (t0, t1, t2, t3) -} - -fn asymmetric_binary_dot_product(x: &[u64], y: &(Vec, Vec, Vec, Vec)) -> u32 { - let t0 = base::simd::bit::sum_of_and(x, &y.0); - let t1 = base::simd::bit::sum_of_and(x, &y.1); - let t2 = base::simd::bit::sum_of_and(x, &y.2); - let t3 = base::simd::bit::sum_of_and(x, &y.3); - (t0 << 0) + (t1 << 1) + (t2 << 2) + (t3 << 3) } diff --git a/src/vchordrq/algorithm/scan.rs b/src/vchordrq/algorithm/scan.rs index 42357c0..e6915f0 100644 --- a/src/vchordrq/algorithm/scan.rs +++ b/src/vchordrq/algorithm/scan.rs @@ -1,14 +1,14 @@ -use super::RelationRead; -use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::vector::VectorBorrowed; +use crate::vchordrq::types::DistanceKind; +use algorithm::{Page, RelationRead}; +use always_equal::AlwaysEqual; +use distance::Distance; use std::cmp::Reverse; use std::collections::BinaryHeap; +use std::num::NonZeroU64; +use vector::VectorBorrowed; pub fn scan( relation: impl RelationRead + Clone, @@ -16,7 +16,7 @@ pub fn scan( distance_kind: DistanceKind, probes: Vec, epsilon: f32, -) -> impl Iterator { +) -> impl Iterator { let vector = 
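// [Illustrative only] The removed fscan_preprocess/binarize/asymmetric_binary_dot_product
// bodies above now come from the rabitq crate (rabitq::binary). What they compute: the query
// is quantized to 4-bit integers, split into four bit-planes, and the dot product against a
// binary data vector becomes four AND + popcount passes, one per plane. A plain sketch:
fn binarize(qvector: &[u8]) -> [Vec<u64>; 4] {
    let mut planes: [Vec<u64>; 4] =
        std::array::from_fn(|_| vec![0u64; qvector.len().div_ceil(64)]);
    for (i, &value) in qvector.iter().enumerate() {
        for (bit, plane) in planes.iter_mut().enumerate() {
            plane[i / 64] |= u64::from((value >> bit) & 1) << (i % 64);
        }
    }
    planes
}

fn asymmetric_binary_dot_product(signs: &[u64], planes: &[Vec<u64>; 4]) -> u32 {
    (0..4)
        .map(|bit| {
            let ones: u32 = signs
                .iter()
                .zip(&planes[bit])
                .map(|(x, y)| (x & y).count_ones())
                .sum();
            ones << bit
        })
        .sum()
}

fn main() {
    // data signs 1,0,1,1 packed little-endian; quantized query values 3,0,2,1
    let signs = vec![0b1101u64];
    let planes = binarize(&[3, 0, 2, 1]);
    // plain dot product of (1,0,1,1) with (3,0,2,1) is 3 + 2 + 1 = 6
    assert_eq!(asymmetric_binary_dot_product(&signs, &planes), 6);
}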
vector.as_borrowed(); let meta_guard = relation.read(0); let meta_tuple = meta_guard @@ -71,7 +71,7 @@ pub fn scan( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = rabitq::process_lowerbound( distance_kind, dims, lut, @@ -143,7 +143,7 @@ pub fn scan( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( + let lowerbounds = rabitq::process_lowerbound( distance_kind, dims, lut, @@ -183,7 +183,7 @@ pub fn scan( cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); } let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; - Some((dis_u, Pointer::new(pay_u))) + Some((dis_u, pay_u)) }) } } diff --git a/src/vchordrq/algorithm/tuples.rs b/src/vchordrq/algorithm/tuples.rs index 40a795c..23fe664 100644 --- a/src/vchordrq/algorithm/tuples.rs +++ b/src/vchordrq/algorithm/tuples.rs @@ -1,10 +1,13 @@ +use std::num::NonZeroU64; + use super::rabitq::{self, Code, Lut}; +use crate::vchordrq::types::DistanceKind; use crate::vchordrq::types::OwnedVector; -use base::distance::DistanceKind; -use base::simd::ScalarLike; -use base::vector::{VectOwned, VectorOwned}; use half::f16; use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize}; +use simd::Floating; +use vector::VectorOwned; +use vector::vect::VectOwned; pub trait Vector: VectorOwned { type Metadata: Copy @@ -70,20 +73,15 @@ impl Vector for VectOwned { type Element = f32; - fn metadata_from_archived(_: &::Archived) -> Self::Metadata { - () - } + fn metadata_from_archived(_: &::Archived) -> Self::Metadata {} fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f32]>) { let vector = vector.slice(); - ( - (), - match vector.len() { - 0..=960 => vec![vector], - 961..=1280 => vec![&vector[..640], &vector[640..]], - 1281.. => vector.chunks(1920).collect(), - }, - ) + ((), match vector.len() { + 0..=960 => vec![vector], + 961..=1280 => vec![&vector[..640], &vector[640..]], + 1281.. => vector.chunks(1920).collect(), + }) } fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { @@ -109,8 +107,6 @@ impl Vector for VectOwned { match accumulator.0 { DistanceKind::L2 => accumulator.1 += f32::reduce_sum_of_d2(left, right), DistanceKind::Dot => accumulator.1 += -f32::reduce_sum_of_xy(left, right), - DistanceKind::Hamming => unreachable!(), - DistanceKind::Jaccard => unreachable!(), } } fn distance_end( @@ -126,11 +122,11 @@ impl Vector for VectOwned { } fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { - Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) + Self::new(Floating::vector_sub(vector.slice(), center.slice())) } fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::fscan_preprocess(vector.slice()) + rabitq::preprocess(vector.slice()) } fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { @@ -151,20 +147,15 @@ impl Vector for VectOwned { type Element = f16; - fn metadata_from_archived(_: &::Archived) -> Self::Metadata { - () - } + fn metadata_from_archived(_: &::Archived) -> Self::Metadata {} fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f16]>) { let vector = vector.slice(); - ( - (), - match vector.len() { - 0..=1920 => vec![vector], - 1921..=2560 => vec![&vector[..1280], &vector[1280..]], - 2561.. => vector.chunks(3840).collect(), - }, - ) + ((), match vector.len() { + 0..=1920 => vec![vector], + 1921..=2560 => vec![&vector[..1280], &vector[1280..]], + 2561.. 
=> vector.chunks(3840).collect(), + }) } fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { @@ -190,8 +181,6 @@ impl Vector for VectOwned { match accumulator.0 { DistanceKind::L2 => accumulator.1 += f16::reduce_sum_of_d2(left, right), DistanceKind::Dot => accumulator.1 += -f16::reduce_sum_of_xy(left, right), - DistanceKind::Hamming => unreachable!(), - DistanceKind::Jaccard => unreachable!(), } } fn distance_end( @@ -209,11 +198,11 @@ impl Vector for VectOwned { } fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { - Self::new(ScalarLike::vector_sub(vector.slice(), center.slice())) + Self::new(Floating::vector_sub(vector.slice(), center.slice())) } fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::fscan_preprocess(&f16::vector_to_f32(vector.slice())) + rabitq::preprocess(&f16::vector_to_f32(vector.slice())) } fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { @@ -246,7 +235,7 @@ pub struct MetaTuple { #[archive(check_bytes)] pub struct VectorTuple { pub slice: Vec, - pub payload: Option, + pub payload: Option, pub chain: Result<(u32, u16), V::Metadata>, } @@ -271,7 +260,7 @@ pub struct Height0Tuple { // raw vector pub mean: (u32, u16), // for height 0 tuple, it's pointers to heap relation - pub payload: u64, + pub payload: NonZeroU64, // RaBitQ algorithm pub dis_u_2: f32, pub factor_ppc: f32, diff --git a/src/vchordrq/algorithm/vacuum.rs b/src/vchordrq/algorithm/vacuum.rs index c77bd21..ee97ca6 100644 --- a/src/vchordrq/algorithm/vacuum.rs +++ b/src/vchordrq/algorithm/vacuum.rs @@ -1,11 +1,11 @@ -use super::RelationWrite; use crate::vchordrq::algorithm::tuples::*; -use base::search::Pointer; +use algorithm::{Page, RelationWrite}; +use std::num::NonZeroU64; pub fn vacuum( relation: impl RelationWrite, delay: impl Fn(), - callback: impl Fn(Pointer) -> bool, + callback: impl Fn(NonZeroU64) -> bool, ) { // step 1: vacuum height_0_tuple { @@ -50,7 +50,7 @@ pub fn vacuum( .map(rkyv::check_archived_root::) .expect("data corruption") .expect("data corruption"); - if callback(Pointer::new(h0_tuple.payload)) { + if callback(h0_tuple.payload) { reconstruct_removes.push(i); } } @@ -81,7 +81,7 @@ pub fn vacuum( let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if callback(Pointer::new(payload)) { + if callback(payload) { break 'flag true; } } @@ -98,7 +98,7 @@ pub fn vacuum( let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if callback(Pointer::new(payload)) { + if callback(payload) { write.free(i); } } diff --git a/src/vchordrq/algorithm/vectors.rs b/src/vchordrq/algorithm/vectors.rs index 06075d3..72e448a 100644 --- a/src/vchordrq/algorithm/vectors.rs +++ b/src/vchordrq/algorithm/vectors.rs @@ -1,14 +1,15 @@ use super::tuples::Vector; -use super::RelationRead; use crate::vchordrq::algorithm::tuples::VectorTuple; -use base::distance::Distance; -use base::distance::DistanceKind; +use crate::vchordrq::types::DistanceKind; +use algorithm::{Page, RelationRead}; +use distance::Distance; +use std::num::NonZeroU64; pub fn vector_dist( relation: impl RelationRead, vector: V::Borrowed<'_>, mean: (u32, u16), - payload: Option, + payload: Option, for_distance: Option, for_original: bool, ) -> Option<(Option, Option)> { diff --git a/src/vchordrq/gucs/executing.rs b/src/vchordrq/gucs/executing.rs index 6b20f74..af6cce7 100644 --- a/src/vchordrq/gucs/executing.rs +++ 
b/src/vchordrq/gucs/executing.rs @@ -68,9 +68,5 @@ pub fn epsilon() -> f32 { pub fn max_scan_tuples() -> Option { let x = MAX_SCAN_TUPLES.get(); - if x < 0 { - None - } else { - Some(x as u32) - } + if x < 0 { None } else { Some(x as u32) } } diff --git a/src/vchordrq/index/am.rs b/src/vchordrq/index/am.rs index 1dc7cdc..3d19a49 100644 --- a/src/vchordrq/index/am.rs +++ b/src/vchordrq/index/am.rs @@ -1,4 +1,4 @@ -use crate::postgres::Relation; +use crate::postgres::PostgresRelation; use crate::vchordrq::algorithm; use crate::vchordrq::algorithm::build::{HeapRelation, Reporter}; use crate::vchordrq::algorithm::tuples::Vector; @@ -7,11 +7,11 @@ use crate::vchordrq::index::am_scan::Scanner; use crate::vchordrq::index::utils::{ctid_to_pointer, pointer_to_ctid}; use crate::vchordrq::index::{am_options, am_scan}; use crate::vchordrq::types::VectorKind; -use base::search::Pointer; -use base::vector::VectOwned; use half::f16; use pgrx::datum::Internal; use pgrx::pg_sys::Datum; +use std::num::NonZeroU64; +use vector::vect::VectOwned; static mut RELOPT_KIND_VCHORDRQ: pgrx::pg_sys::relopt_kind::Type = 0; @@ -170,7 +170,7 @@ pub unsafe extern "C" fn ambuild( impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, V)), + F: FnMut((NonZeroU64, V)), { pub struct State<'a, F> { pub this: &'a Heap, @@ -185,7 +185,7 @@ pub unsafe extern "C" fn ambuild( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((Pointer, V)), + F: FnMut((NonZeroU64, V)), { let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; @@ -242,7 +242,7 @@ pub unsafe extern "C" fn ambuild( opfamily, }; let mut reporter = PgReporter {}; - let index_relation = unsafe { Relation::new(index) }; + let index_relation = unsafe { PostgresRelation::new(index) }; match opfamily.vector_kind() { VectorKind::Vecf32 => algorithm::build::build::, Heap, _>( vector_options, @@ -289,7 +289,7 @@ pub unsafe extern "C" fn ambuild( } else { let mut indtuples = 0; reporter.tuples_done(indtuples); - let relation = unsafe { Relation::new(index) }; + let relation = unsafe { PostgresRelation::new(index) }; match opfamily.vector_kind() { VectorKind::Vecf32 => { HeapRelation::>::traverse( @@ -575,7 +575,7 @@ unsafe fn parallel_build( impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, V)), + F: FnMut((NonZeroU64, V)), { pub struct State<'a, F> { pub this: &'a Heap, @@ -590,7 +590,7 @@ unsafe fn parallel_build( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((Pointer, V)), + F: FnMut((NonZeroU64, V)), { let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; @@ -627,7 +627,7 @@ unsafe fn parallel_build( } } - let index_relation = unsafe { Relation::new(index) }; + let index_relation = unsafe { PostgresRelation::new(index) }; let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap, tablescandesc) }; let opfamily = unsafe { am_options::opfamily(index) }; @@ -716,14 +716,14 @@ pub unsafe extern "C" fn aminsert( let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); match opfamily.vector_kind() { VectorKind::Vecf32 => algorithm::insert::insert::>( - unsafe { Relation::new(index) }, + unsafe { PostgresRelation::new(index) }, pointer, VectOwned::::from_owned(vector), opfamily.distance_kind(), false, ), VectorKind::Vecf16 => algorithm::insert::insert::>( - unsafe { Relation::new(index) }, + unsafe { PostgresRelation::new(index) }, pointer, 
VectOwned::::from_owned(vector), opfamily.distance_kind(), @@ -752,14 +752,14 @@ pub unsafe extern "C" fn aminsert( let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); match opfamily.vector_kind() { VectorKind::Vecf32 => algorithm::insert::insert::>( - unsafe { Relation::new(index) }, + unsafe { PostgresRelation::new(index) }, pointer, VectOwned::::from_owned(vector), opfamily.distance_kind(), false, ), VectorKind::Vecf16 => algorithm::insert::insert::>( - unsafe { Relation::new(index) }, + unsafe { PostgresRelation::new(index) }, pointer, VectOwned::::from_owned(vector), opfamily.distance_kind(), @@ -854,7 +854,7 @@ pub unsafe extern "C" fn amgettuple( pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported"); } let scanner = unsafe { (*scan).opaque.cast::().as_mut().unwrap_unchecked() }; - let relation = unsafe { Relation::new((*scan).indexRelation) }; + let relation = unsafe { PostgresRelation::new((*scan).indexRelation) }; if let Some((pointer, recheck)) = am_scan::scan_next(scanner, relation) { let ctid = pointer_to_ctid(pointer); unsafe { @@ -892,17 +892,17 @@ pub unsafe extern "C" fn ambulkdelete( } let opfamily = unsafe { am_options::opfamily((*info).index) }; let callback = callback.unwrap(); - let callback = |p: Pointer| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; + let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; match opfamily.vector_kind() { VectorKind::Vecf32 => algorithm::vacuum::vacuum::>( - unsafe { Relation::new((*info).index) }, + unsafe { PostgresRelation::new((*info).index) }, || unsafe { pgrx::pg_sys::vacuum_delay_point(); }, callback, ), VectorKind::Vecf16 => algorithm::vacuum::vacuum::>( - unsafe { Relation::new((*info).index) }, + unsafe { PostgresRelation::new((*info).index) }, || unsafe { pgrx::pg_sys::vacuum_delay_point(); }, diff --git a/src/vchordrq/index/am_options.rs b/src/vchordrq/index/am_options.rs index 25fcd0c..5c730ed 100644 --- a/src/vchordrq/index/am_options.rs +++ b/src/vchordrq/index/am_options.rs @@ -3,16 +3,16 @@ use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecOutput; use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use crate::datatype::typmod::Typmod; -use crate::vchordrq::types::VchordrqIndexingOptions; -use crate::vchordrq::types::VectorOptions; -use crate::vchordrq::types::{BorrowedVector, OwnedVector, VectorKind}; -use base::distance::*; -use base::vector::VectorBorrowed; +use crate::vchordrq::types::{BorrowedVector, OwnedVector}; +use crate::vchordrq::types::{DistanceKind, VectorKind}; +use crate::vchordrq::types::{VchordrqIndexingOptions, VectorOptions}; +use distance::Distance; use pgrx::datum::FromDatum; use pgrx::heap_tuple::PgHeapTuple; use serde::Deserialize; use std::ffi::CStr; use std::num::NonZero; +use vector::VectorBorrowed; #[derive(Copy, Clone, Debug, Default)] #[repr(C)] diff --git a/src/vchordrq/index/am_scan.rs b/src/vchordrq/index/am_scan.rs index 1b78ff0..6e2d30d 100644 --- a/src/vchordrq/index/am_scan.rs +++ b/src/vchordrq/index/am_scan.rs @@ -1,5 +1,5 @@ use super::am_options::Opfamily; -use crate::postgres::Relation; +use crate::postgres::PostgresRelation; use crate::vchordrq::algorithm::scan::scan; use crate::vchordrq::algorithm::tuples::Vector; use crate::vchordrq::gucs::executing::epsilon; @@ -7,10 +7,10 @@ use crate::vchordrq::gucs::executing::max_scan_tuples; use crate::vchordrq::gucs::executing::probes; use 
crate::vchordrq::types::OwnedVector; use crate::vchordrq::types::VectorKind; -use base::distance::Distance; -use base::search::*; -use base::vector::VectOwned; +use distance::Distance; use half::f16; +use std::num::NonZeroU64; +use vector::vect::VectOwned; pub enum Scanner { Initial { @@ -19,7 +19,7 @@ pub enum Scanner { recheck: bool, }, Vbase { - vbase: Box>, + vbase: Box>, threshold: Option, recheck: bool, opfamily: Opfamily, @@ -66,7 +66,7 @@ pub fn scan_make( } } -pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, bool)> { +pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(NonZeroU64, bool)> { if let Scanner::Initial { vector, threshold, diff --git a/src/vchordrq/index/functions.rs b/src/vchordrq/index/functions.rs index 05f348f..32e6d03 100644 --- a/src/vchordrq/index/functions.rs +++ b/src/vchordrq/index/functions.rs @@ -1,11 +1,11 @@ use super::am_options; -use crate::postgres::Relation; +use crate::postgres::PostgresRelation; use crate::vchordrq::algorithm::prewarm::prewarm; use crate::vchordrq::types::VectorKind; -use base::vector::VectOwned; use half::f16; use pgrx::pg_sys::Oid; use pgrx_catalog::{PgAm, PgClass}; +use vector::vect::VectOwned; #[pgrx::pg_extern(sql = "")] fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { @@ -21,7 +21,7 @@ fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { pgrx::error!("{:?} is not a vchordrq index", pg_class.relname()); } let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; - let relation = unsafe { Relation::new(index) }; + let relation = unsafe { PostgresRelation::new(index) }; let opfamily = unsafe { am_options::opfamily(index) }; let message = match opfamily.vector_kind() { VectorKind::Vecf32 => prewarm::>(relation, height), diff --git a/src/vchordrq/index/utils.rs b/src/vchordrq/index/utils.rs index a5d85a3..726a597 100644 --- a/src/vchordrq/index/utils.rs +++ b/src/vchordrq/index/utils.rs @@ -1,7 +1,7 @@ -use base::search::*; +use std::num::NonZeroU64; -pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData { - let value = pointer.as_u64(); +pub fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { + let value = pointer.get(); pgrx::pg_sys::ItemPointerData { ip_blkid: pgrx::pg_sys::BlockIdData { bi_hi: ((value >> 32) & 0xffff) as u16, @@ -11,10 +11,10 @@ pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData { } } -pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> Pointer { +pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { let mut value = 0; value |= (ctid.ip_blkid.bi_hi as u64) << 32; value |= (ctid.ip_blkid.bi_lo as u64) << 16; value |= ctid.ip_posid as u64; - Pointer::new(value) + NonZeroU64::new(value).expect("invalid pointer") } diff --git a/src/vchordrq/types.rs b/src/vchordrq/types.rs index 0e1bdc0..4ef2171 100644 --- a/src/vchordrq/types.rs +++ b/src/vchordrq/types.rs @@ -1,8 +1,7 @@ -use base::distance::DistanceKind; -use base::vector::{VectBorrowed, VectOwned}; use half::f16; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError, ValidationErrors}; +use vector::vect::{VectBorrowed, VectOwned}; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] @@ -111,6 +110,13 @@ pub enum BorrowedVector<'a> { Vecf16(VectBorrowed<'a, f16>), } +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum DistanceKind { + L2, 
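// [Illustrative only] The index/utils.rs hunk above swaps base::search::Pointer for
// NonZeroU64; the heap pointer is packed as (block hi << 32) | (block lo << 16) | offset.
// A round-trip check using plain fields instead of pgrx::pg_sys::ItemPointerData:
use std::num::NonZeroU64;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Ctid {
    bi_hi: u16,
    bi_lo: u16,
    ip_posid: u16,
}

fn ctid_to_pointer(ctid: Ctid) -> NonZeroU64 {
    let value =
        ((ctid.bi_hi as u64) << 32) | ((ctid.bi_lo as u64) << 16) | ctid.ip_posid as u64;
    // PostgreSQL offset numbers start at 1, so a valid heap ctid never packs to zero
    NonZeroU64::new(value).expect("invalid pointer")
}

fn pointer_to_ctid(pointer: NonZeroU64) -> Ctid {
    let value = pointer.get();
    Ctid {
        bi_hi: ((value >> 32) & 0xffff) as u16,
        bi_lo: ((value >> 16) & 0xffff) as u16,
        ip_posid: (value & 0xffff) as u16,
    }
}

fn main() {
    let ctid = Ctid { bi_hi: 0x0001, bi_lo: 0x00ff, ip_posid: 7 };
    assert_eq!(pointer_to_ctid(ctid_to_pointer(ctid)), ctid);
}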
+ Dot, +} + #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum VectorKind { diff --git a/src/vchordrqfscan/algorithm/build.rs b/src/vchordrqfscan/algorithm/build.rs index 9a0daa6..a245d77 100644 --- a/src/vchordrqfscan/algorithm/build.rs +++ b/src/vchordrqfscan/algorithm/build.rs @@ -1,25 +1,24 @@ -use crate::postgres::BufferWriteGuard; -use crate::postgres::Relation; use crate::vchordrqfscan::algorithm::rabitq; use crate::vchordrqfscan::algorithm::tuples::*; use crate::vchordrqfscan::index::am_options::Opfamily; +use crate::vchordrqfscan::types::DistanceKind; use crate::vchordrqfscan::types::VchordrqfscanBuildOptions; use crate::vchordrqfscan::types::VchordrqfscanExternalBuildOptions; use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; use crate::vchordrqfscan::types::VchordrqfscanInternalBuildOptions; use crate::vchordrqfscan::types::VectorOptions; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::simd::ScalarLike; +use algorithm::{Page, PageGuard, RelationWrite}; use rand::Rng; use rkyv::ser::serializers::AllocSerializer; +use simd::Floating; use std::marker::PhantomData; +use std::num::NonZeroU64; use std::sync::Arc; pub trait HeapRelation { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((Pointer, Vec)); + F: FnMut((NonZeroU64, Vec)); fn opfamily(&self) -> Opfamily; } @@ -31,7 +30,7 @@ pub fn build( vector_options: VectorOptions, vchordrqfscan_options: VchordrqfscanIndexingOptions, heap_relation: T, - relation: Relation, + relation: impl RelationWrite, mut reporter: R, ) { let dims = vector_options.dims; @@ -73,7 +72,7 @@ pub fn build( }; let mut meta = Tape::create(&relation, false); assert_eq!(meta.first(), 0); - let mut forwards = Tape::::create(&relation, false); + let mut forwards = Tape::::create(&relation, false); assert_eq!(forwards.first(), 1); let mut vectors = Tape::create(&relation, true); assert_eq!(vectors.first(), 2); @@ -94,10 +93,10 @@ pub fn build( let mut level = Vec::new(); for j in 0..structures[i].len() { if i == 0 { - let tape = Tape::::create(&relation, false); + let tape = Tape::::create(&relation, false); level.push(tape.first()); } else { - let mut tape = Tape::::create(&relation, false); + let mut tape = Tape::::create(&relation, false); let mut cache = Vec::new(); let h2_mean = &structures[i].means[j]; let h2_children = &structures[i].children[j]; @@ -250,8 +249,8 @@ impl Structure { let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; - use base::vector::VectorBorrowed; use pgrx::pg_sys::panic::ErrorReportable; + use vector::VectorBorrowed; let table = client.select(&query, None, None).unwrap_or_report(); for row in table { let id: Option = row.get_by_name("id").unwrap(); @@ -368,16 +367,16 @@ impl Structure { } } -struct Tape<'a, T> { - relation: &'a Relation, - head: BufferWriteGuard, +struct Tape<'a, 'b, T, R: 'b + RelationWrite> { + relation: &'a R, + head: R::WriteGuard<'b>, first: u32, tracking_freespace: bool, _phantom: PhantomData T>, } -impl<'a, T> Tape<'a, T> { - fn create(relation: &'a Relation, tracking_freespace: bool) -> Self { +impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 'b, T, R> { + fn create(relation: &'a R, tracking_freespace: bool) -> Self { let head = relation.extend(tracking_freespace); let first = head.id(); Self { @@ -393,7 +392,7 @@ impl<'a, T> Tape<'a, T> { } } -impl Tape<'_, T> +impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 
'b, T, R> where T: rkyv::Serialize>, { diff --git a/src/vchordrqfscan/algorithm/insert.rs b/src/vchordrqfscan/algorithm/insert.rs index 1a0ab77..ee44760 100644 --- a/src/vchordrqfscan/algorithm/insert.rs +++ b/src/vchordrqfscan/algorithm/insert.rs @@ -1,17 +1,22 @@ -use crate::postgres::Relation; use crate::vchordrqfscan::algorithm::rabitq; -use crate::vchordrqfscan::algorithm::rabitq::distance; use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; use crate::vchordrqfscan::algorithm::tuples::*; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::simd::ScalarLike; +use crate::vchordrqfscan::types::DistanceKind; +use crate::vchordrqfscan::types::distance; +use algorithm::{Page, PageGuard, RelationWrite}; +use always_equal::AlwaysEqual; +use distance::Distance; +use simd::Floating; use std::cmp::Reverse; use std::collections::BinaryHeap; +use std::num::NonZeroU64; -pub fn insert(relation: Relation, payload: Pointer, vector: Vec, distance_kind: DistanceKind) { +pub fn insert( + relation: impl RelationWrite, + payload: NonZeroU64, + vector: Vec, + distance_kind: DistanceKind, +) { let meta_guard = relation.read(0); let meta_tuple = meta_guard .get(1) @@ -30,7 +35,7 @@ pub fn insert(relation: Relation, payload: Pointer, vector: Vec, distance_k let h0_vector = 'h0_vector: { let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple { vector: vector.clone(), - payload: Some(payload.as_u64()), + payload: Some(payload), }) .unwrap(); if let Some(mut write) = relation.search(tuple.len()) { @@ -78,7 +83,7 @@ pub fn insert(relation: Relation, payload: Pointer, vector: Vec, distance_k changed = true; } }; - let h0_payload = payload.as_u64(); + let h0_payload = payload; let mut list = ( meta_tuple.first, if is_residual { @@ -173,7 +178,7 @@ pub fn insert(relation: Relation, payload: Pointer, vector: Vec, distance_k let dummy = rkyv::to_bytes::<_, 8192>(&Height0Tuple { mask: [false; 32], mean: [(0, 0); 32], - payload: [0; 32], + payload: [NonZeroU64::MIN; 32], dis_u_2: [0.0f32; 32], factor_ppc: [0.0f32; 32], factor_ip: [0.0f32; 32], diff --git a/src/vchordrqfscan/algorithm/prewarm.rs b/src/vchordrqfscan/algorithm/prewarm.rs index ec8976a..c8500d4 100644 --- a/src/vchordrqfscan/algorithm/prewarm.rs +++ b/src/vchordrqfscan/algorithm/prewarm.rs @@ -1,8 +1,8 @@ -use crate::postgres::Relation; use crate::vchordrqfscan::algorithm::tuples::*; +use algorithm::{Page, RelationRead}; use std::fmt::Write; -pub fn prewarm(relation: Relation, height: i32) -> String { +pub fn prewarm(relation: impl RelationRead, height: i32) -> String { let mut message = String::new(); let meta_guard = relation.read(0); let meta_tuple = meta_guard diff --git a/src/vchordrqfscan/algorithm/rabitq.rs b/src/vchordrqfscan/algorithm/rabitq.rs index 65c4996..707d81c 100644 --- a/src/vchordrqfscan/algorithm/rabitq.rs +++ b/src/vchordrqfscan/algorithm/rabitq.rs @@ -1,171 +1,22 @@ -use crate::utils::infinite_byte_chunks::InfiniteByteChunks; -use base::distance::{Distance, DistanceKind}; -use base::simd::ScalarLike; +use crate::vchordrqfscan::types::DistanceKind; +use distance::Distance; -#[derive(Debug, Clone)] -pub struct Code { - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub signs: Vec, -} - -pub fn code(dims: u32, vector: &[f32]) -> Code { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x_2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x_2.sqrt(); - let x0 = sum_of_abs_x 
/ (sum_of_x_2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - let mut signs = Vec::new(); - for i in 0..dims { - signs.push(vector[i as usize].is_sign_positive() as u8); - } - Code { - dis_u_2: sum_of_x_2, - factor_ppc, - factor_ip, - factor_err, - signs, - } -} - -pub fn dummy_code(dims: u32) -> Code { - Code { - dis_u_2: 0.0, - factor_ppc: 0.0, - factor_ip: 0.0, - factor_err: 0.0, - signs: vec![0; dims as _], - } -} - -pub struct PackedCodes { - pub dis_u_2: [f32; 32], - pub factor_ppc: [f32; 32], - pub factor_ip: [f32; 32], - pub factor_err: [f32; 32], - pub t: Vec, -} - -pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes { - PackedCodes { - dis_u_2: std::array::from_fn(|i| codes[i].dis_u_2), - factor_ppc: std::array::from_fn(|i| codes[i].factor_ppc), - factor_ip: std::array::from_fn(|i| codes[i].factor_ip), - factor_err: std::array::from_fn(|i| codes[i].factor_err), - t: { - let signs = codes.map(|code| { - InfiniteByteChunks::new(code.signs.into_iter()) - .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) - .take(dims.div_ceil(4) as usize) - .collect::>() - }); - base::simd::fast_scan::b4::pack(dims.div_ceil(4), signs).collect() - }, - } -} - -pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { - use base::simd::quantize; - let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize(vector, 15.0); - let qvector_sum = if vector.len() <= 4369 { - base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 - } else { - base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32 - }; - (dis_v_2, b, k, qvector_sum, compress(qvector)) -} +pub use rabitq::block::Code; +pub use rabitq::block::code; +pub use rabitq::block::dummy_code; +pub use rabitq::block::fscan_preprocess; +pub use rabitq::block::pack_codes; +pub use rabitq::block::{fscan_process_lowerbound_dot, fscan_process_lowerbound_l2}; pub fn fscan_process_lowerbound( distance_kind: DistanceKind, dims: u32, lut: &(f32, f32, f32, f32, Vec), - (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( - &[f32; 32], - &[f32; 32], - &[f32; 32], - &[f32; 32], - &[u8], - ), + code: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), epsilon: f32, ) -> [Distance; 32] { - let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let r = base::simd::fast_scan::b4::fast_scan_b4(dims.div_ceil(4), t, s); match distance_kind { - DistanceKind::L2 => std::array::from_fn(|i| { - let rough = dis_u_2[i] - + dis_v_2 - + b * factor_ppc[i] - + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; - let err = factor_err[i] * dis_v_2.sqrt(); - Distance::from_f32(rough - epsilon * err) - }), - DistanceKind::Dot => std::array::from_fn(|i| { - let rough = 0.5 * b * factor_ppc[i] - + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; - let err = 0.5 * factor_err[i] * dis_v_2.sqrt(); - Distance::from_f32(rough - epsilon * err) - }), - DistanceKind::Hamming => unreachable!(), - DistanceKind::Jaccard => unreachable!(), - } -} - -fn compress(mut qvector: Vec) -> Vec { - let dims = qvector.len() as u32; - let width = dims.div_ceil(4); - 
qvector.resize(qvector.len().next_multiple_of(4), 0); - let mut t = vec![0u8; width as usize * 16]; - for i in 0..width as usize { - unsafe { - // this hint is used to skip bound checks - std::hint::assert_unchecked(4 * i + 3 < qvector.len()); - std::hint::assert_unchecked(16 * i + 15 < t.len()); - } - let t0 = qvector[4 * i + 0]; - let t1 = qvector[4 * i + 1]; - let t2 = qvector[4 * i + 2]; - let t3 = qvector[4 * i + 3]; - t[16 * i + 0b0000] = 0; - t[16 * i + 0b0001] = t0; - t[16 * i + 0b0010] = t1; - t[16 * i + 0b0011] = t1 + t0; - t[16 * i + 0b0100] = t2; - t[16 * i + 0b0101] = t2 + t0; - t[16 * i + 0b0110] = t2 + t1; - t[16 * i + 0b0111] = t2 + t1 + t0; - t[16 * i + 0b1000] = t3; - t[16 * i + 0b1001] = t3 + t0; - t[16 * i + 0b1010] = t3 + t1; - t[16 * i + 0b1011] = t3 + t1 + t0; - t[16 * i + 0b1100] = t3 + t2; - t[16 * i + 0b1101] = t3 + t2 + t0; - t[16 * i + 0b1110] = t3 + t2 + t1; - t[16 * i + 0b1111] = t3 + t2 + t1 + t0; - } - t -} - -pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { - match d { - DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), - DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)), - DistanceKind::Hamming => unimplemented!(), - DistanceKind::Jaccard => unimplemented!(), + DistanceKind::L2 => fscan_process_lowerbound_l2(dims, lut, code, epsilon), + DistanceKind::Dot => fscan_process_lowerbound_dot(dims, lut, code, epsilon), } } diff --git a/src/vchordrqfscan/algorithm/scan.rs b/src/vchordrqfscan/algorithm/scan.rs index a691b6c..5546af9 100644 --- a/src/vchordrqfscan/algorithm/scan.rs +++ b/src/vchordrqfscan/algorithm/scan.rs @@ -1,23 +1,23 @@ -use crate::postgres::Relation; use crate::vchordrqfscan::algorithm::rabitq; -use crate::vchordrqfscan::algorithm::rabitq::distance; use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; use crate::vchordrqfscan::algorithm::tuples::*; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::distance::DistanceKind; -use base::search::Pointer; -use base::simd::ScalarLike; +use crate::vchordrqfscan::types::DistanceKind; +use crate::vchordrqfscan::types::distance; +use algorithm::{Page, RelationWrite}; +use always_equal::AlwaysEqual; +use distance::Distance; +use simd::Floating; use std::cmp::Reverse; use std::collections::BinaryHeap; +use std::num::NonZeroU64; pub fn scan( - relation: Relation, + relation: impl RelationWrite + Clone, vector: Vec, distance_kind: DistanceKind, probes: Vec, epsilon: f32, -) -> impl Iterator { +) -> impl Iterator { let meta_guard = relation.read(0); let meta_tuple = meta_guard .get(1) @@ -123,6 +123,7 @@ pub fn scan( for i in (1..meta_tuple.height_of_root).rev() { lists = make_lists(lists, probes[i as usize - 1]); } + drop(meta_guard); { let mut results = Vec::new(); for list in lists { @@ -186,7 +187,7 @@ pub fn scan( cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); } let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; - Some((dis_u, Pointer::new(pay_u))) + Some((dis_u, pay_u)) }) } } diff --git a/src/vchordrqfscan/algorithm/tuples.rs b/src/vchordrqfscan/algorithm/tuples.rs index 3b43dac..ff94713 100644 --- a/src/vchordrqfscan/algorithm/tuples.rs +++ b/src/vchordrqfscan/algorithm/tuples.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU64; + use crate::vchordrqfscan::algorithm::rabitq; use rkyv::{Archive, Deserialize, Serialize}; @@ -20,7 +22,7 @@ pub struct MetaTuple { pub struct VectorTuple { pub vector: Vec, // this field is saved only for vacuum - pub payload: Option, + pub payload: Option, 
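// [Illustrative only] The removed compress/generate helpers above build, for each group of
// four 4-bit quantized query values t0..t3, a 16-entry table whose entry m is the sum of the
// values selected by the 4-bit mask m; fast-scan then turns each packed 4-bit data code into
// a single table lookup instead of four multiply-adds. A standalone sketch of one table:
fn build_lut(group: [u8; 4]) -> [u8; 16] {
    let mut lut = [0u8; 16];
    for m in 0..16usize {
        // sum the query values whose bit is set in the mask m
        lut[m] = (0..4).filter(|&j| (m >> j) & 1 == 1).map(|j| group[j]).sum();
    }
    lut
}

fn main() {
    let lut = build_lut([3, 1, 4, 1]);
    assert_eq!(lut[0b0000], 0);
    assert_eq!(lut[0b0101], 3 + 4); // bits 0 and 2 set -> t0 + t2
    assert_eq!(lut[0b1111], 3 + 1 + 4 + 1);
}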
diff --git a/src/vchordrqfscan/algorithm/scan.rs b/src/vchordrqfscan/algorithm/scan.rs
index a691b6c..5546af9 100644
--- a/src/vchordrqfscan/algorithm/scan.rs
+++ b/src/vchordrqfscan/algorithm/scan.rs
@@ -1,23 +1,23 @@
-use crate::postgres::Relation;
 use crate::vchordrqfscan::algorithm::rabitq;
-use crate::vchordrqfscan::algorithm::rabitq::distance;
 use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound;
 use crate::vchordrqfscan::algorithm::tuples::*;
-use base::always_equal::AlwaysEqual;
-use base::distance::Distance;
-use base::distance::DistanceKind;
-use base::search::Pointer;
-use base::simd::ScalarLike;
+use crate::vchordrqfscan::types::DistanceKind;
+use crate::vchordrqfscan::types::distance;
+use algorithm::{Page, RelationWrite};
+use always_equal::AlwaysEqual;
+use distance::Distance;
+use simd::Floating;
 use std::cmp::Reverse;
 use std::collections::BinaryHeap;
+use std::num::NonZeroU64;
 
 pub fn scan(
-    relation: Relation,
+    relation: impl RelationWrite + Clone,
     vector: Vec<f32>,
     distance_kind: DistanceKind,
     probes: Vec<u32>,
     epsilon: f32,
-) -> impl Iterator<Item = (Distance, Pointer)> {
+) -> impl Iterator<Item = (Distance, NonZeroU64)> {
     let meta_guard = relation.read(0);
     let meta_tuple = meta_guard
         .get(1)
@@ -123,6 +123,7 @@ pub fn scan(
     for i in (1..meta_tuple.height_of_root).rev() {
         lists = make_lists(lists, probes[i as usize - 1]);
     }
+    drop(meta_guard);
     {
         let mut results = Vec::new();
         for list in lists {
@@ -186,7 +187,7 @@ pub fn scan(
             cache.push((Reverse(dis_u), AlwaysEqual(pay_u)));
         }
         let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?;
-        Some((dis_u, Pointer::new(pay_u)))
+        Some((dis_u, pay_u))
     })
 }
diff --git a/src/vchordrqfscan/algorithm/tuples.rs b/src/vchordrqfscan/algorithm/tuples.rs
index 3b43dac..ff94713 100644
--- a/src/vchordrqfscan/algorithm/tuples.rs
+++ b/src/vchordrqfscan/algorithm/tuples.rs
@@ -1,3 +1,5 @@
+use std::num::NonZeroU64;
+
 use crate::vchordrqfscan::algorithm::rabitq;
 use rkyv::{Archive, Deserialize, Serialize};
 
@@ -20,7 +22,7 @@ pub struct MetaTuple {
 pub struct VectorTuple {
     pub vector: Vec<f32>,
     // this field is saved only for vacuum
-    pub payload: Option<u64>,
+    pub payload: Option<NonZeroU64>,
 }
 
 #[derive(Clone, PartialEq, Archive, Serialize, Deserialize)]
@@ -46,7 +48,7 @@ pub struct Height0Tuple {
     // raw vector
     pub mean: [(u32, u16); 32],
     // for height 0 tuple, it's pointers to heap relation
-    pub payload: [u64; 32],
+    pub payload: [NonZeroU64; 32],
     // RaBitQ algorithm
     pub dis_u_2: [f32; 32],
     pub factor_ppc: [f32; 32],
@@ -60,7 +62,7 @@ pub fn put(
     dims: u32,
     code: &rabitq::Code,
     vector: (u32, u16),
-    payload: u64,
+    payload: NonZeroU64,
 ) -> bool {
     // todo: use mutable api
     let mut x = rkyv::from_bytes::<Height0Tuple>(bytes).expect("data corruption");
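Note on the payload change above: heap pointers packed by ctid_to_pointer are never zero (a valid ip_posid starts at 1), so the field can use NonZeroU64. Rust then guarantees that Option<NonZeroU64> occupies the same 8 bytes as a plain u64 via niche optimization, so the optional vacuum payload need not grow in memory (whether the rkyv-archived layout also benefits depends on rkyv's representation). A small self-contained illustration, not code from the patch:

    use std::num::NonZeroU64;

    fn main() {
        // The value 0 is reserved as the niche, so no separate Option tag is needed.
        assert_eq!(
            std::mem::size_of::<Option<NonZeroU64>>(),
            std::mem::size_of::<u64>()
        );
        // A valid packed heap pointer is never 0, so the conversion cannot fail.
        let payload = NonZeroU64::new(0x0001_0000_0001).expect("non-zero");
        let saved: Option<NonZeroU64> = Some(payload);
        assert_eq!(saved.map(NonZeroU64::get), Some(0x0001_0000_0001));
    }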
diff --git a/src/vchordrqfscan/algorithm/vacuum.rs b/src/vchordrqfscan/algorithm/vacuum.rs
index 8ede4f6..7773ed2 100644
--- a/src/vchordrqfscan/algorithm/vacuum.rs
+++ b/src/vchordrqfscan/algorithm/vacuum.rs
@@ -1,9 +1,13 @@
-use crate::postgres::Relation;
 use crate::vchordrqfscan::algorithm::tuples::VectorTuple;
 use crate::vchordrqfscan::algorithm::tuples::*;
-use base::search::Pointer;
+use algorithm::{Page, RelationWrite};
+use std::num::NonZeroU64;
 
-pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -> bool) {
+pub fn vacuum(
+    relation: impl RelationWrite,
+    delay: impl Fn(),
+    callback: impl Fn(NonZeroU64) -> bool,
+) {
     // step 1: vacuum height_0_tuple
     {
         let meta_guard = relation.read(0);
@@ -52,7 +56,7 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -
             .expect("data corruption");
         let flag = 'flag: {
             for j in 0..32 {
-                if h0_tuple.mask[j] && callback(Pointer::new(h0_tuple.payload[j])) {
+                if h0_tuple.mask[j] && callback(h0_tuple.payload[j]) {
                     break 'flag true;
                 }
             }
@@ -66,7 +70,7 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -
             .expect("data corruption")
             .expect("data corruption");
         for j in 0..32 {
-            if temp.mask[j] && callback(Pointer::new(temp.payload[j])) {
+            if temp.mask[j] && callback(temp.payload[j]) {
                 temp.mask[j] = false;
             }
         }
@@ -104,7 +108,7 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -
                 let vector_tuple = rkyv::check_archived_root::<VectorTuple>(vector_tuple)
                     .expect("data corruption");
                 if let Some(payload) = vector_tuple.payload.as_ref().copied() {
-                    if callback(Pointer::new(payload)) {
+                    if callback(payload) {
                         break 'flag true;
                     }
                 }
@@ -121,7 +125,7 @@ pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -
                 let vector_tuple = rkyv::check_archived_root::<VectorTuple>(vector_tuple)
                     .expect("data corruption");
                 if let Some(payload) = vector_tuple.payload.as_ref().copied() {
-                    if callback(Pointer::new(payload)) {
+                    if callback(payload) {
                         write.free(i);
                     }
                 }
diff --git a/src/vchordrqfscan/gucs/executing.rs b/src/vchordrqfscan/gucs/executing.rs
index 6ec186e..ba204b9 100644
--- a/src/vchordrqfscan/gucs/executing.rs
+++ b/src/vchordrqfscan/gucs/executing.rs
@@ -68,9 +68,5 @@ pub fn epsilon() -> f32 {
 
 pub fn max_scan_tuples() -> Option<u32> {
     let x = MAX_SCAN_TUPLES.get();
-    if x < 0 {
-        None
-    } else {
-        Some(x as u32)
-    }
+    if x < 0 { None } else { Some(x as u32) }
 }
diff --git a/src/vchordrqfscan/index/am.rs b/src/vchordrqfscan/index/am.rs
index 2b834fb..7011e26 100644
--- a/src/vchordrqfscan/index/am.rs
+++ b/src/vchordrqfscan/index/am.rs
@@ -1,13 +1,13 @@
-use crate::postgres::Relation;
+use crate::postgres::PostgresRelation;
 use crate::vchordrqfscan::algorithm;
 use crate::vchordrqfscan::algorithm::build::{HeapRelation, Reporter};
 use crate::vchordrqfscan::index::am_options::{Opfamily, Reloption};
 use crate::vchordrqfscan::index::am_scan::Scanner;
 use crate::vchordrqfscan::index::utils::{ctid_to_pointer, pointer_to_ctid};
 use crate::vchordrqfscan::index::{am_options, am_scan};
-use base::search::Pointer;
 use pgrx::datum::Internal;
 use pgrx::pg_sys::Datum;
+use std::num::NonZeroU64;
 
 static mut RELOPT_KIND_VCHORDRQFSCAN: pgrx::pg_sys::relopt_kind::Type = 0;
 
@@ -166,7 +166,7 @@ pub unsafe extern "C" fn ambuild(
     impl HeapRelation for Heap {
         fn traverse<F>(&self, progress: bool, callback: F)
         where
-            F: FnMut((Pointer, Vec<f32>)),
+            F: FnMut((NonZeroU64, Vec<f32>)),
         {
             pub struct State<'a, F> {
                 pub this: &'a Heap,
@@ -181,7 +181,7 @@ pub unsafe extern "C" fn ambuild(
                 _tuple_is_alive: bool,
                 state: *mut core::ffi::c_void,
             ) where
-                F: FnMut((Pointer, Vec<f32>)),
+                F: FnMut((NonZeroU64, Vec<f32>)),
             {
                 use crate::vchordrqfscan::types::OwnedVector;
                 let state = unsafe { &mut *state.cast::<State<F>>() };
@@ -242,7 +242,7 @@ pub unsafe extern "C" fn ambuild(
         opfamily,
     };
     let mut reporter = PgReporter {};
-    let index_relation = unsafe { Relation::new(index) };
+    let index_relation = unsafe { PostgresRelation::new(index) };
    algorithm::build::build(
         vector_options,
         vchordrqfscan_options,
@@ -552,7 +552,7 @@ unsafe fn parallel_build(
     impl HeapRelation for Heap {
         fn traverse<F>(&self, progress: bool, callback: F)
         where
-            F: FnMut((Pointer, Vec<f32>)),
+            F: FnMut((NonZeroU64, Vec<f32>)),
         {
             pub struct State<'a, F> {
                 pub this: &'a Heap,
@@ -567,7 +567,7 @@ unsafe fn parallel_build(
                 _tuple_is_alive: bool,
                 state: *mut core::ffi::c_void,
             ) where
-                F: FnMut((Pointer, Vec<f32>)),
+                F: FnMut((NonZeroU64, Vec<f32>)),
             {
                 use crate::vchordrqfscan::types::OwnedVector;
                 let state = unsafe { &mut *state.cast::<State<F>>() };
@@ -608,7 +608,7 @@ unsafe fn parallel_build(
         }
     }
 
-    let index_relation = unsafe { Relation::new(index) };
+    let index_relation = unsafe { PostgresRelation::new(index) };
     let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap, tablescandesc) };
     let opfamily = unsafe { am_options::opfamily(index) };
     let heap_relation = Heap {
@@ -672,7 +672,7 @@ pub unsafe extern "C" fn aminsert(
     };
     let pointer = ctid_to_pointer(unsafe { heap_tid.read() });
     algorithm::insert::insert(
-        unsafe { Relation::new(index) },
+        unsafe { PostgresRelation::new(index) },
         pointer,
         vector.into_vec(),
         opfamily.distance_kind(),
@@ -702,7 +702,7 @@ pub unsafe extern "C" fn aminsert(
     };
     let pointer = ctid_to_pointer(unsafe { heap_tid.read() });
     algorithm::insert::insert(
-        unsafe { Relation::new(index) },
+        unsafe { PostgresRelation::new(index) },
         pointer,
         vector.into_vec(),
         opfamily.distance_kind(),
@@ -795,7 +795,7 @@ pub unsafe extern "C" fn amgettuple(
         pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported");
     }
     let scanner = unsafe { (*scan).opaque.cast::<Scanner>().as_mut().unwrap_unchecked() };
-    let relation = unsafe { Relation::new((*scan).indexRelation) };
+    let relation = unsafe { PostgresRelation::new((*scan).indexRelation) };
     if let Some((pointer, recheck)) = am_scan::scan_next(scanner, relation) {
         let ctid = pointer_to_ctid(pointer);
         unsafe {
@@ -832,9 +832,9 @@ pub unsafe extern "C" fn ambulkdelete(
         };
     }
     let callback = callback.unwrap();
-    let callback = |p: Pointer| unsafe { callback(&mut pointer_to_ctid(p), callback_state) };
+    let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) };
     algorithm::vacuum::vacuum(
-        unsafe { Relation::new((*info).index) },
+        unsafe { PostgresRelation::new((*info).index) },
         || unsafe {
             pgrx::pg_sys::vacuum_delay_point();
         },
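Note on the vacuum/bulkdelete path above: ambulkdelete wraps PostgreSQL's dead-tuple callback into a closure over the packed NonZeroU64 payload, and the vacuum pass then walks the height-0 tuples, clearing the mask bit of every entry whose heap tuple is reported dead. The core of that sweep, reduced to a hypothetical free function (the name sweep and the standalone signature are invented for illustration):

    use std::num::NonZeroU64;

    // Clear the mask bit of every live slot whose heap tuple the callback reports
    // as dead, mirroring the `temp.mask[j] = false` branch in the vacuum hunk above.
    fn sweep(
        mask: &mut [bool; 32],
        payload: &[NonZeroU64; 32],
        is_dead: impl Fn(NonZeroU64) -> bool,
    ) {
        for j in 0..32 {
            if mask[j] && is_dead(payload[j]) {
                mask[j] = false;
            }
        }
    }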
diff --git a/src/vchordrqfscan/index/am_options.rs b/src/vchordrqfscan/index/am_options.rs
index b49b7a2..be4154e 100644
--- a/src/vchordrqfscan/index/am_options.rs
+++ b/src/vchordrqfscan/index/am_options.rs
@@ -1,14 +1,16 @@
 use crate::datatype::memory_pgvector_vector::PgvectorVectorInput;
 use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput;
 use crate::datatype::typmod::Typmod;
-use crate::vchordrqfscan::types::*;
-use base::distance::*;
-use base::vector::VectorBorrowed;
+use crate::vchordrqfscan::types::{BorrowedVector, OwnedVector};
+use crate::vchordrqfscan::types::{DistanceKind, VectorKind};
+use crate::vchordrqfscan::types::{VchordrqfscanIndexingOptions, VectorOptions};
+use distance::Distance;
 use pgrx::datum::FromDatum;
 use pgrx::heap_tuple::PgHeapTuple;
 use serde::Deserialize;
 use std::ffi::CStr;
 use std::num::NonZero;
+use vector::VectorBorrowed;
 
 #[derive(Copy, Clone, Debug, Default)]
 #[repr(C)]
diff --git a/src/vchordrqfscan/index/am_scan.rs b/src/vchordrqfscan/index/am_scan.rs
index b07edb7..da049ab 100644
--- a/src/vchordrqfscan/index/am_scan.rs
+++ b/src/vchordrqfscan/index/am_scan.rs
@@ -1,12 +1,12 @@
 use super::am_options::Opfamily;
-use crate::postgres::Relation;
+use crate::postgres::PostgresRelation;
 use crate::vchordrqfscan::algorithm::scan::scan;
 use crate::vchordrqfscan::gucs::executing::epsilon;
 use crate::vchordrqfscan::gucs::executing::max_scan_tuples;
 use crate::vchordrqfscan::gucs::executing::probes;
 use crate::vchordrqfscan::types::OwnedVector;
-use base::distance::Distance;
-use base::search::*;
+use distance::Distance;
+use std::num::NonZeroU64;
 
 pub enum Scanner {
     Initial {
@@ -15,7 +15,7 @@ pub enum Scanner {
         recheck: bool,
     },
     Vbase {
-        vbase: Box<dyn Iterator<Item = (Distance, Pointer)>>,
+        vbase: Box<dyn Iterator<Item = (Distance, NonZeroU64)>>,
         threshold: Option,
         recheck: bool,
         opfamily: Opfamily,
@@ -62,7 +62,7 @@ pub fn scan_make(
     }
 }
 
-pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, bool)> {
+pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(NonZeroU64, bool)> {
     if let Scanner::Initial {
         vector,
         threshold,
diff --git a/src/vchordrqfscan/index/functions.rs b/src/vchordrqfscan/index/functions.rs
index 98f0f25..27bd9ac 100644
--- a/src/vchordrqfscan/index/functions.rs
+++ b/src/vchordrqfscan/index/functions.rs
@@ -1,4 +1,4 @@
-use crate::postgres::Relation;
+use crate::postgres::PostgresRelation;
 use crate::vchordrqfscan::algorithm::prewarm::prewarm;
 use pgrx::pg_sys::Oid;
 use pgrx_catalog::{PgAm, PgClass};
@@ -17,7 +17,7 @@ fn _vchordrqfscan_prewarm(indexrelid: Oid, height: i32) -> String {
         pgrx::error!("{:?} is not a vchordrqfscan index", pg_class.relname());
     }
     let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) };
-    let relation = unsafe { Relation::new(index) };
+    let relation = unsafe { PostgresRelation::new(index) };
     let message = prewarm(relation, height);
     unsafe {
         pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _);
diff --git a/src/vchordrqfscan/index/utils.rs b/src/vchordrqfscan/index/utils.rs
index a5d85a3..726a597 100644
--- a/src/vchordrqfscan/index/utils.rs
+++ b/src/vchordrqfscan/index/utils.rs
@@ -1,7 +1,7 @@
-use base::search::*;
+use std::num::NonZeroU64;
 
-pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData {
-    let value = pointer.as_u64();
+pub fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData {
+    let value = pointer.get();
     pgrx::pg_sys::ItemPointerData {
         ip_blkid: pgrx::pg_sys::BlockIdData {
             bi_hi: ((value >> 32) & 0xffff) as u16,
@@ -11,10 +11,10 @@ pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData {
     }
 }
 
-pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> Pointer {
+pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 {
     let mut value = 0;
     value |= (ctid.ip_blkid.bi_hi as u64) << 32;
     value |= (ctid.ip_blkid.bi_lo as u64) << 16;
     value |= ctid.ip_posid as u64;
-    Pointer::new(value)
+    NonZeroU64::new(value).expect("invalid pointer")
 }
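Note on utils.rs above: a heap ctid is packed into a single integer with bi_hi in bits 32..=47, bi_lo in bits 16..=31 and ip_posid in bits 0..=15; a valid offset number is at least 1, so the packed value cannot be zero, which is what the NonZeroU64::new(value).expect("invalid pointer") conversion relies on. A self-contained round-trip check, using a plain stand-in struct instead of the pg_sys types and illustrative names pack/unpack:

    use std::num::NonZeroU64;

    // Stand-in for pg_sys::ItemPointerData, only for this round-trip example.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct Ctid {
        bi_hi: u16,
        bi_lo: u16,
        ip_posid: u16,
    }

    fn pack(ctid: Ctid) -> NonZeroU64 {
        let value =
            ((ctid.bi_hi as u64) << 32) | ((ctid.bi_lo as u64) << 16) | ctid.ip_posid as u64;
        NonZeroU64::new(value).expect("invalid pointer") // ip_posid >= 1 for a valid ctid
    }

    fn unpack(pointer: NonZeroU64) -> Ctid {
        let value = pointer.get();
        Ctid {
            bi_hi: ((value >> 32) & 0xffff) as u16,
            bi_lo: ((value >> 16) & 0xffff) as u16,
            ip_posid: (value & 0xffff) as u16,
        }
    }

    fn main() {
        let ctid = Ctid { bi_hi: 0, bi_lo: 7, ip_posid: 3 };
        assert_eq!(unpack(pack(ctid)), ctid);
    }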
diff --git a/src/vchordrqfscan/types.rs b/src/vchordrqfscan/types.rs
index 1180e64..91ca769 100644
--- a/src/vchordrqfscan/types.rs
+++ b/src/vchordrqfscan/types.rs
@@ -1,7 +1,7 @@
-use base::distance::DistanceKind;
-use base::vector::{VectBorrowed, VectOwned};
+use distance::Distance;
 use serde::{Deserialize, Serialize};
 use validator::{Validate, ValidationError, ValidationErrors};
+use vector::vect::{VectBorrowed, VectOwned};
 
 #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
 #[serde(deny_unknown_fields)]
@@ -108,6 +108,13 @@ pub enum BorrowedVector<'a> {
     Vecf32(VectBorrowed<'a, f32>),
 }
 
+#[repr(u8)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub enum DistanceKind {
+    L2,
+    Dot,
+}
+
 #[repr(u8)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
 pub enum VectorKind {
@@ -136,3 +143,11 @@ impl VectorOptions {
         }
     }
 }
+
+pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance {
+    use simd::Floating;
+    match d {
+        DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)),
+        DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)),
+    }
+}
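Note on the new types::distance helper above: L2 maps to the sum of squared differences and Dot to the negated inner product, so both orderings agree that a smaller value means a closer vector; reduce_sum_of_d2 and reduce_sum_of_xy come from the new simd crate. Scalar equivalents of the two reductions, assuming the semantics their names suggest (this is not the simd crate's implementation):

    // Plain-Rust versions of the reductions used by types::distance.
    fn reduce_sum_of_d2(lhs: &[f32], rhs: &[f32]) -> f32 {
        assert_eq!(lhs.len(), rhs.len());
        lhs.iter().zip(rhs).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    fn reduce_sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 {
        assert_eq!(lhs.len(), rhs.len());
        lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
    }

    fn main() {
        let (a, b) = ([1.0f32, 0.0, 2.0], [0.0f32, 1.0, 2.0]);
        assert_eq!(reduce_sum_of_d2(&a, &b), 2.0); // L2: squared Euclidean distance
        assert_eq!(-reduce_sum_of_xy(&a, &b), -4.0); // Dot: more similar => smaller value
    }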
DIR="./target/opt" - elif [[ " $@ " =~ " --profile release " ]]; then - DIR="./target/release" else DIR="./target/debug" fi