diff --git a/.clang-format b/.clang-format
index 643005d0..d99fd0eb 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,10 +1,69 @@
 Language: Cpp
-BasedOnStyle: LLVM
+BasedOnStyle: LLVM
 IndentWidth: 4
 TabWidth: 4
 NamespaceIndentation: None
 ColumnLimit: 120
 ReflowComments: true
 UseTab: Never
-PointerAlignment: Left
+
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlines: true
+AlignOperands: true
+AlignTrailingComments: true
+
+AllowAllArgumentsOnNextLine: false
+AllowAllConstructorInitializersOnNextLine: true
+AllowAllParametersOfDeclarationOnNextLine: true
+AllowShortBlocksOnASingleLine: Always
+AllowShortIfStatementsOnASingleLine: Always
 AllowShortCaseLabelsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: true
+AllowShortLambdasOnASingleLine: true
+AllowShortLoopsOnASingleLine: true
+AlwaysBreakTemplateDeclarations: Yes
+AlwaysBreakAfterReturnType: None
+PenaltyReturnTypeOnItsOwnLine: 200
+
+BreakBeforeBraces: Custom
+BraceWrapping:
+  AfterCaseLabel: false
+  AfterClass: false
+  AfterControlStatement: false
+  AfterEnum: false
+  AfterExternBlock: false
+  AfterFunction: false
+  AfterStruct: false
+  AfterNamespace: false
+  AfterUnion: false
+  BeforeCatch: true
+  BeforeElse: true
+  SplitEmptyFunction: false
+  SplitEmptyRecord: false
+  SplitEmptyNamespace: false
+  IndentBraces: false
+
+SortIncludes: true
+SortUsingDeclarations: true
+
+SpaceAfterCStyleCast: false
+SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCpp11BracedList: true
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles: false
+SpacesInCStyleCastParentheses: false
+SpacesInContainerLiterals: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+
+BinPackArguments: true
+BinPackParameters: true
+PenaltyBreakBeforeFirstCallParameter: 1
\ No newline at end of file
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 00000000..57b3f619
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1 @@
+a4022a988287e527757ecc9bc16a4f2e7dc4770e
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 21e4e377..9d7b4c8f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,7 +4,7 @@ project(
     simsimd
     VERSION 5.8.0
     LANGUAGES C CXX
-    DESCRIPTION "Fastest SIMD-Accelerated Vector Similarity Functions for x86 and Arm"
+    DESCRIPTION "Portable mixed-precision BLAS-like vector math library for x86 and ARM"
     HOMEPAGE_URL "https://github.com/ashvardanian/simsimd"
 )
@@ -16,7 +16,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED YES)
 set(CMAKE_CXX_EXTENSIONS NO)
 
-# Determine if StringZilla is built as a subproject (using `add_subdirectory`) or if it is the main project
+# Determine if SimSIMD is built as a subproject (using `add_subdirectory`) or if it is the main project
 set(SIMSIMD_IS_MAIN_PROJECT OFF)
 
 if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
@@ -79,7 +79,7 @@ if (SIMSIMD_BUILD_BENCHMARKS)
     )
     FetchContent_MakeAvailable(benchmark)
 
-    # Remove the google benchmark built in debug warning
+    # Remove Google Benchmark's built-in debug warning
     if (CMAKE_BUILD_TYPE STREQUAL "Release")
         target_compile_definitions(benchmark PRIVATE NDEBUG)
     endif ()
@@ -88,12 +88,19 @@ if (SIMSIMD_BUILD_BENCHMARKS)
     add_executable(simsimd_bench scripts/bench.cxx)
     target_link_libraries(simsimd_bench simsimd Threads::Threads benchmark)
 
-    find_package(BLAS)
-
-    if (BLAS_FOUND AND SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS)
-        target_compile_definitions(simsimd_bench PRIVATE SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1)
-        target_link_libraries(simsimd_bench ${BLAS_LIBRARIES})
+    if (SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS)
+        find_package(BLAS REQUIRED)
+        if (BLAS_FOUND)
+            message(STATUS "BLAS found: ${BLAS_LIBRARIES}")
+            include_directories(${BLAS_INCLUDE_DIRS})
+            target_include_directories(simsimd_bench PRIVATE ${BLAS_INCLUDE_DIRS})
+            target_link_libraries(simsimd_bench ${BLAS_LIBRARIES})
+            target_compile_definitions(simsimd_bench PRIVATE SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1)
+        else ()
+            message(FATAL_ERROR "BLAS not found")
+        endif ()
     endif ()
+    endif ()
 
 if (SIMSIMD_BUILD_TESTS)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 14348c7b..0728a344 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -29,15 +29,37 @@ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 100
 sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 100
 ```
 
-On MacOS it's recommended to use Homebrew and install Clang, as opposed to "Apple Clang".
-Replacing the default compiler is not recommended, as it may break the system, but you can pass it as an environment variable:
+To compile with the default Apple Clang on macOS, use:
 
 ```sh
-brew install llvm
-cmake -D CMAKE_BUILD_TYPE=Release -D SIMSIMD_BUILD_TESTS=1 \
-    -D CMAKE_C_COMPILER="$(brew --prefix llvm)/bin/clang" \
-    -D CMAKE_CXX_COMPILER="$(brew --prefix llvm)/bin/clang++" \
-    -B build_release
+brew install openblas
+cmake -D CMAKE_BUILD_TYPE=Release \
+    -D SIMSIMD_BUILD_TESTS=1 \
+    -D SIMSIMD_BUILD_BENCHMARKS=1 \
+    -D SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1 \
+    -D CMAKE_PREFIX_PATH="$(brew --prefix openblas)" \
+    -D CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES="$(brew --prefix openblas)/include" \
+    -B build_release
+cmake --build build_release --config Release
+```
+
+On macOS it's recommended to use Homebrew and install Clang, as opposed to "Apple Clang".
+Replacing the default compiler across the entire system is not recommended on macOS, as it may break the system, but you can point CMake at the Homebrew LLVM explicitly:
+
+```sh
+brew install llvm openblas
+cmake -D CMAKE_BUILD_TYPE=Release \
+    -D SIMSIMD_BUILD_TESTS=1 \
+    -D SIMSIMD_BUILD_BENCHMARKS=1 \
+    -D SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1 \
+    -D CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES="$(brew --prefix openblas)/include" \
+    -D CMAKE_C_LINK_FLAGS="-L$(xcrun --sdk macosx --show-sdk-path)/usr/lib" \
+    -D CMAKE_EXE_LINKER_FLAGS="-L$(xcrun --sdk macosx --show-sdk-path)/usr/lib" \
+    -D CMAKE_C_COMPILER="$(brew --prefix llvm)/bin/clang" \
+    -D CMAKE_CXX_COMPILER="$(brew --prefix llvm)/bin/clang++" \
+    -D CMAKE_OSX_SYSROOT="$(xcrun --sdk macosx --show-sdk-path)" \
+    -D CMAKE_OSX_DEPLOYMENT_TARGET=$(sw_vers -productVersion) \
+    -B build_release
 cmake --build build_release --config Release
 ```
diff --git a/Cargo.toml b/Cargo.toml
index 9ac3eea3..04a08adc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "simsimd"
-description = "Fastest SIMD-Accelerated Vector Similarity Functions for x86 and Arm"
+description = "Portable mixed-precision BLAS-like vector math library for x86 and ARM"
 version = "5.8.0"
 edition = "2021"
 license = "Apache-2.0"
diff --git a/README.md b/README.md
index 8d1fd0c5..63744309 100644
--- a/README.md
+++ b/README.md
@@ -729,7 +729,8 @@ To explicitly disable half-precision support, define the following macro before
 > But if you are running on different generations of devices, it makes sense to pre-compile the library for all supported generations at once, and dispatch at runtime.
 > This flag does just that and is used to produce the `simsimd.so` shared library, as well as the Python and other bindings.
 
-`SIMSIMD_TARGET_ARM` (`SIMSIMD_TARGET_NEON`, `SIMSIMD_TARGET_SVE`, `SIMSIMD_TARGET_SVE2`, `SIMSIMD_TARGET_NEON_F16`, `SIMSIMD_TARGET_SVE_F16`, `SIMSIMD_TARGET_NEON_BF16`, `SIMSIMD_TARGET_SVE_BF16`), `SIMSIMD_TARGET_X86` (`SIMSIMD_TARGET_HASWELL`, `SIMSIMD_TARGET_SKYLAKE`, `SIMSIMD_TARGET_ICE`, `SIMSIMD_TARGET_GENOA`, `SIMSIMD_TARGET_SAPPHIRE`, `SIMSIMD_TARGET_TURIN`, `SIMSIMD_TARGET_SIERRA`):
+For Arm: `SIMSIMD_TARGET_NEON`, `SIMSIMD_TARGET_SVE`, `SIMSIMD_TARGET_SVE2`, `SIMSIMD_TARGET_NEON_F16`, `SIMSIMD_TARGET_SVE_F16`, `SIMSIMD_TARGET_NEON_BF16`, `SIMSIMD_TARGET_SVE_BF16`.
+For x86: `SIMSIMD_TARGET_HASWELL`, `SIMSIMD_TARGET_SKYLAKE`, `SIMSIMD_TARGET_ICE`, `SIMSIMD_TARGET_GENOA`, `SIMSIMD_TARGET_SAPPHIRE`, `SIMSIMD_TARGET_TURIN`, `SIMSIMD_TARGET_SIERRA`.
 
 > By default, SimSIMD automatically infers the target architecture and pre-compiles as many kernels as possible.
 > In some cases, you may want to explicitly disable some of the kernels.
@@ -753,6 +754,7 @@ In general there are a few principles that SimSIMD follows:
 - Avoid returning from public interfaces, use out-arguments instead.
 - Don't over-optimize for old CPUs and single- and double-precision floating-point numbers.
 - Prioritize mixed-precision and integer operations, and new ISA extensions.
+- Prefer saturated arithmetic and avoid overflows.
 
 Possibly, in the future:
diff --git a/c/lib.c b/c/lib.c
index 81f3e8a2..951f50ed 100644
--- a/c/lib.c
+++ b/c/lib.c
@@ -55,56 +55,56 @@ extern "C" {
 // If no metric is found, it returns NaN. We can obtain NaN by dividing 0.0 by 0.0, but that annoys
 // the MSVC compiler. Instead we can directly write-in the signaling NaN (0x7FF0000000000001)
 // or the qNaN (0x7FF8000000000000).
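As an aside on the comment above: rather than evaluating `0.0 / 0.0`, the dispatch shims store a NaN bit pattern straight into the result. A minimal standalone sketch of the same trick follows; the `write_nan` helper is hypothetical (not part of the library), and `memcpy` is shown as the strictly-conforming alternative to the pointer cast the macros below use:

```c
#include <stdint.h>
#include <string.h>

// Store an IEEE-754 NaN bit pattern into a double without computing
// 0.0 / 0.0, which MSVC complains about. 0x7FF0000000000001 has the
// quiet bit clear and a non-zero mantissa (a signaling NaN);
// 0x7FF8000000000000 is the canonical quiet NaN.
static void write_nan(double *result) {
    uint64_t const bits = 0x7FF0000000000001ull;
    memcpy(result, &bits, sizeof bits); // well-defined, unlike *(uint64_t *)result
}
```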
-#define SIMSIMD_DECLARATION_DENSE(name, extension, type) \ - SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const* a, simsimd_##type##_t const* b, \ - simsimd_size_t n, simsimd_distance_t* results) { \ - static simsimd_metric_punned_t metric = 0; \ - if (metric == 0) { \ - simsimd_capability_t used_capability; \ - simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ - simsimd_capabilities(), simsimd_cap_any_k, &metric, &used_capability); \ - if (!metric) { \ - *(simsimd_u64_t*)results = 0x7FF0000000000001ull; \ - return; \ - } \ - } \ - metric(a, b, n, results); \ +#define SIMSIMD_DECLARATION_DENSE(name, extension, type) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const *a, simsimd_##type##_t const *b, \ + simsimd_size_t n, simsimd_distance_t *results) { \ + static simsimd_metric_punned_t metric = 0; \ + if (metric == 0) { \ + simsimd_capability_t used_capability; \ + simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ + simsimd_capabilities(), simsimd_cap_any_k, &metric, &used_capability); \ + if (!metric) { \ + *(simsimd_u64_t *)results = 0x7FF0000000000001ull; \ + return; \ + } \ + } \ + metric(a, b, n, results); \ } -#define SIMSIMD_DECLARATION_SPARSE(name, extension, type) \ - SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const* a, simsimd_##type##_t const* b, \ - simsimd_size_t a_length, simsimd_size_t b_length, \ - simsimd_distance_t* result) { \ - static simsimd_metric_sparse_punned_t metric = 0; \ - if (metric == 0) { \ - simsimd_capability_t used_capability; \ - simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ - simsimd_capabilities(), simsimd_cap_any_k, (simsimd_metric_punned_t*)(&metric), \ - &used_capability); \ - if (!metric) { \ - *(simsimd_u64_t*)result = 0x7FF0000000000001ull; \ - return; \ - } \ - } \ - metric(a, b, a_length, b_length, result); \ +#define SIMSIMD_DECLARATION_SPARSE(name, extension, type) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const *a, simsimd_##type##_t const *b, \ + simsimd_size_t a_length, simsimd_size_t b_length, \ + simsimd_distance_t *result) { \ + static simsimd_metric_sparse_punned_t metric = 0; \ + if (metric == 0) { \ + simsimd_capability_t used_capability; \ + simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ + simsimd_capabilities(), simsimd_cap_any_k, \ + (simsimd_metric_punned_t *)(&metric), &used_capability); \ + if (!metric) { \ + *(simsimd_u64_t *)result = 0x7FF0000000000001ull; \ + return; \ + } \ + } \ + metric(a, b, a_length, b_length, result); \ } -#define SIMSIMD_DECLARATION_CURVED(name, extension, type) \ - SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const* a, simsimd_##type##_t const* b, \ - simsimd_##type##_t const* c, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - static simsimd_metric_curved_punned_t metric = 0; \ - if (metric == 0) { \ - simsimd_capability_t used_capability; \ - simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ - simsimd_capabilities(), simsimd_cap_any_k, (simsimd_metric_punned_t*)(&metric), \ - &used_capability); \ - if (!metric) { \ - *(simsimd_u64_t*)result = 0x7FF0000000000001ull; \ - return; \ - } \ - } \ - metric(a, b, c, n, result); \ +#define SIMSIMD_DECLARATION_CURVED(name, extension, type) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const 
*a, simsimd_##type##_t const *b, \ + simsimd_##type##_t const *c, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + static simsimd_metric_curved_punned_t metric = 0; \ + if (metric == 0) { \ + simsimd_capability_t used_capability; \ + simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ + simsimd_capabilities(), simsimd_cap_any_k, \ + (simsimd_metric_punned_t *)(&metric), &used_capability); \ + if (!metric) { \ + *(simsimd_u64_t *)result = 0x7FF0000000000001ull; \ + return; \ + } \ + } \ + metric(a, b, c, n, result); \ } // Dot products @@ -191,8 +191,7 @@ SIMSIMD_DYNAMIC int simsimd_uses_dynamic_dispatch(void) { return 1; } SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { //! The latency of the CPUID instruction can be over 100 cycles, so we cache the result. static simsimd_capability_t static_capabilities = simsimd_cap_any_k; - if (static_capabilities != simsimd_cap_any_k) - return static_capabilities; + if (static_capabilities != simsimd_cap_any_k) return static_capabilities; static_capabilities = _simsimd_capabilities_implementation(); @@ -200,73 +199,74 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { // so the first time we are probing for capabilities, we should also probe all of our metrics // with dummy inputs: simsimd_distance_t dummy_results_buffer[2]; - simsimd_distance_t* dummy_results = &dummy_results_buffer[0]; - void* dummy = 0; + simsimd_distance_t *dummy_results = &dummy_results_buffer[0]; + void *dummy = 0; // Dense: - simsimd_dot_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); - simsimd_dot_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); - simsimd_dot_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_dot_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_dot_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_dot_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_dot_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); + simsimd_dot_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); + simsimd_dot_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_dot_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_dot_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_dot_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_dot_f16c((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_dot_bf16c((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_dot_f32c((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_dot_f64c((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); - simsimd_vdot_f16c((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_vdot_bf16c((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_vdot_f32c((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_vdot_f64c((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_dot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_dot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_dot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_dot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, 
dummy_results); + simsimd_vdot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_vdot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_vdot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_vdot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_cos_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); - simsimd_cos_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); - simsimd_cos_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_cos_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_cos_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_cos_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_cos_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); + simsimd_cos_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); + simsimd_cos_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_cos_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_cos_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_cos_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_l2sq_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); - simsimd_l2sq_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); - simsimd_l2sq_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_l2sq_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_l2sq_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_l2sq_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_l2sq_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); + simsimd_l2sq_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); + simsimd_l2sq_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_l2sq_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_l2sq_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_l2sq_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_l2_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); - simsimd_l2_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); - simsimd_l2_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); - simsimd_l2_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_l2_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_l2_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_l2_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); + simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); + simsimd_l2_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); + simsimd_l2_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_l2_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_l2_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_l2_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - 
simsimd_hamming_b8((simsimd_b8_t*)dummy, (simsimd_b8_t*)dummy, 0, dummy_results); - simsimd_jaccard_b8((simsimd_b8_t*)dummy, (simsimd_b8_t*)dummy, 0, dummy_results); + simsimd_hamming_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results); + simsimd_jaccard_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results); - simsimd_kl_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_kl_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_kl_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_kl_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); - simsimd_js_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_js_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_js_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_js_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_kl_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_kl_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_kl_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_kl_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); + simsimd_js_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_js_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_js_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_js_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); // Sparse - simsimd_intersect_u16((simsimd_u16_t*)dummy, (simsimd_u16_t*)dummy, 0, 0, dummy_results); - simsimd_intersect_u32((simsimd_u32_t*)dummy, (simsimd_u32_t*)dummy, 0, 0, dummy_results); + simsimd_intersect_u16((simsimd_u16_t *)dummy, (simsimd_u16_t *)dummy, 0, 0, dummy_results); + simsimd_intersect_u32((simsimd_u32_t *)dummy, (simsimd_u32_t *)dummy, 0, 0, dummy_results); // Curved: - simsimd_bilinear_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); - simsimd_mahalanobis_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); - simsimd_bilinear_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_mahalanobis_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); - simsimd_bilinear_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_mahalanobis_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); - simsimd_bilinear_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); - simsimd_mahalanobis_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); + simsimd_bilinear_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); + simsimd_mahalanobis_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); + simsimd_bilinear_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_mahalanobis_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); + simsimd_bilinear_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); 
+ simsimd_mahalanobis_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); + simsimd_bilinear_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); + simsimd_mahalanobis_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, + dummy_results); return static_capabilities; } @@ -276,8 +276,8 @@ SIMSIMD_DYNAMIC void simsimd_find_metric_punned( // simsimd_datatype_t datatype, // simsimd_capability_t supported, // simsimd_capability_t allowed, // - simsimd_metric_punned_t* metric_output, // - simsimd_capability_t* capability_output) { + simsimd_metric_punned_t *metric_output, // + simsimd_capability_t *capability_output) { _simsimd_find_metric_punned_implementation(kind, datatype, supported, allowed, metric_output, capability_output); } diff --git a/include/simsimd/binary.h b/include/simsimd/binary.h index 7165e568..be452882 100644 --- a/include/simsimd/binary.h +++ b/include/simsimd/binary.h @@ -60,23 +60,22 @@ SIMSIMD_PUBLIC unsigned char simsimd_popcount_b8(simsimd_b8_t x) { return lookup_table[x]; } -SIMSIMD_PUBLIC void simsimd_hamming_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_hamming_b8_serial(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_i32_t differences = 0; - for (simsimd_size_t i = 0; i != n_words; ++i) - differences += simsimd_popcount_b8(a[i] ^ b[i]); + for (simsimd_size_t i = 0; i != n_words; ++i) differences += simsimd_popcount_b8(a[i] ^ b[i]); *result = differences; } -SIMSIMD_PUBLIC void simsimd_jaccard_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8_serial(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_i32_t intersection = 0, union_ = 0; for (simsimd_size_t i = 0; i != n_words; ++i) intersection += simsimd_popcount_b8(a[i] & b[i]), union_ += simsimd_popcount_b8(a[i] | b[i]); *result = (union_ != 0) ? 1 - (simsimd_f64_t)intersection / (simsimd_f64_t)union_ : 1; } -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") @@ -97,8 +96,8 @@ SIMSIMD_INTERNAL simsimd_u32_t _simsimd_reduce_u8x16_neon(uint8x16_t vec) { return final_sum; } -SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_i32_t differences = 0; simsimd_size_t i = 0; // In each 8-bit word we may have up to 8 differences. 
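For reference, the whole Hamming computation these serial, NEON, and SVE kernels implement reduces to XOR-ing the packed words and counting set bits. A scalar sketch using the compiler's `__builtin_popcount` (assuming GCC or Clang; the library's serial path uses a 256-entry lookup table instead):

```c
#include <stddef.h>

// Scalar reference for the b8 Hamming kernels: every XOR-ed byte
// contributes up to 8 differing bits, which the SIMD versions
// accumulate in wide registers instead of a scalar counter.
static int hamming_b8_reference(unsigned char const *a, unsigned char const *b, size_t n_words) {
    int differences = 0;
    for (size_t i = 0; i != n_words; ++i)
        differences += __builtin_popcount((unsigned)(a[i] ^ b[i]));
    return differences;
}
```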
@@ -116,13 +115,12 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const* a, simsimd_b8_t differences += _simsimd_reduce_u8x16_neon(differences_cycle_vec); } // Handle the tail - for (; i != n_words; ++i) - differences += simsimd_popcount_b8(a[i] ^ b[i]); + for (; i != n_words; ++i) differences += simsimd_popcount_b8(a[i] ^ b[i]); *result = differences; } -SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_i32_t intersection = 0, union_ = 0; simsimd_size_t i = 0; // In each 8-bit word we may have up to 8 intersections/unions. @@ -158,8 +156,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const* a, simsimd_b8_t #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // On very small register sizes, NEON is at least as fast as SVE. simsimd_size_t const words_per_register = svcntb(); @@ -191,8 +189,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const* a, simsimd_b8_t c *result = differences; } -SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // On very small register sizes, NEON is at least as fast as SVE. simsimd_size_t const words_per_register = svcntb(); @@ -232,17 +230,17 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t c #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SVE -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_ICE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vpopcntdq") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vpopcntdq"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vpopcntdq"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_size_t xor_count; // It's harder to squeeze out performance from tiny representations, so we unroll the loops for binary metrics. @@ -252,7 +250,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i b_vec = _mm512_maskz_loadu_epi8(mask, b); __m512i xor_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a_vec, b_vec)); xor_count = _mm512_reduce_add_epi64(xor_count_vec); - } else if (n_words <= 128) { // Up to 1024 bits. + } + else if (n_words <= 128) { // Up to 1024 bits. 
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 64); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -261,7 +260,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i xor1_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a1_vec, b1_vec)); __m512i xor2_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a2_vec, b2_vec)); xor_count = _mm512_reduce_add_epi64(_mm512_add_epi64(xor2_count_vec, xor1_count_vec)); - } else if (n_words <= 196) { // Up to 1568 bits. + } + else if (n_words <= 196) { // Up to 1568 bits. __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 128); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -274,7 +274,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i xor3_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a3_vec, b3_vec)); xor_count = _mm512_reduce_add_epi64(_mm512_add_epi64(xor3_count_vec, _mm512_add_epi64(xor2_count_vec, xor1_count_vec))); - } else if (n_words <= 256) { // Up to 2048 bits. + } + else if (n_words <= 256) { // Up to 2048 bits. __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 192); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -290,7 +291,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i xor4_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a4_vec, b4_vec)); xor_count = _mm512_reduce_add_epi64(_mm512_add_epi64(_mm512_add_epi64(xor4_count_vec, xor3_count_vec), _mm512_add_epi64(xor2_count_vec, xor1_count_vec))); - } else { + } + else { __m512i xor_count_vec = _mm512_setzero_si512(); __m512i a_vec, b_vec; @@ -300,23 +302,23 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c a_vec = _mm512_maskz_loadu_epi8(mask, a); b_vec = _mm512_maskz_loadu_epi8(mask, b); n_words = 0; - } else { + } + else { a_vec = _mm512_loadu_epi8(a); b_vec = _mm512_loadu_epi8(b); a += 64, b += 64, n_words -= 64; } __m512i xor_vec = _mm512_xor_si512(a_vec, b_vec); xor_count_vec = _mm512_add_epi64(xor_count_vec, _mm512_popcnt_epi64(xor_vec)); - if (n_words) - goto simsimd_hamming_b8_ice_cycle; + if (n_words) goto simsimd_hamming_b8_ice_cycle; xor_count = _mm512_reduce_add_epi64(xor_count_vec); } *result = xor_count; } -SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_size_t intersection = 0, union_ = 0; //? On such vectors we can clearly see that the CPU struggles to perform this many parallel @@ -341,7 +343,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i or_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a_vec, b_vec)); intersection = _mm512_reduce_add_epi64(and_count_vec); union_ = _mm512_reduce_add_epi64(or_count_vec); - } else if (n_words <= 128) { // Up to 1024 bits. + } + else if (n_words <= 128) { // Up to 1024 bits. 
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 64); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -353,7 +356,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i or2_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a2_vec, b2_vec)); intersection = _mm512_reduce_add_epi64(_mm512_add_epi64(and2_count_vec, and1_count_vec)); union_ = _mm512_reduce_add_epi64(_mm512_add_epi64(or2_count_vec, or1_count_vec)); - } else if (n_words <= 196) { // Up to 1568 bits. + } + else if (n_words <= 196) { // Up to 1568 bits. __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 128); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -371,7 +375,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c _mm512_add_epi64(and3_count_vec, _mm512_add_epi64(and2_count_vec, and1_count_vec))); union_ = _mm512_reduce_add_epi64( // _mm512_add_epi64(or3_count_vec, _mm512_add_epi64(or2_count_vec, or1_count_vec))); - } else if (n_words <= 256) { // Up to 2048 bits. + } + else if (n_words <= 256) { // Up to 2048 bits. __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 192); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); @@ -393,7 +398,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c _mm512_add_epi64(and2_count_vec, and1_count_vec))); union_ = _mm512_reduce_add_epi64(_mm512_add_epi64(_mm512_add_epi64(or4_count_vec, or3_count_vec), _mm512_add_epi64(or2_count_vec, or1_count_vec))); - } else { + } + else { __m512i and_count_vec = _mm512_setzero_si512(), or_count_vec = _mm512_setzero_si512(); __m512i a_vec, b_vec; @@ -403,7 +409,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c a_vec = _mm512_maskz_loadu_epi8(mask, a); b_vec = _mm512_maskz_loadu_epi8(mask, b); n_words = 0; - } else { + } + else { a_vec = _mm512_loadu_epi8(a); b_vec = _mm512_loadu_epi8(b); a += 64, b += 64, n_words -= 64; @@ -412,8 +419,7 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c __m512i or_vec = _mm512_or_si512(a_vec, b_vec); and_count_vec = _mm512_add_epi64(and_count_vec, _mm512_popcnt_epi64(and_vec)); or_count_vec = _mm512_add_epi64(or_count_vec, _mm512_popcnt_epi64(or_vec)); - if (n_words) - goto simsimd_jaccard_b8_ice_cycle; + if (n_words) goto simsimd_jaccard_b8_ice_cycle; intersection = _mm512_reduce_add_epi64(and_count_vec); union_ = _mm512_reduce_add_epi64(or_count_vec); @@ -430,33 +436,31 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c #pragma GCC target("popcnt") #pragma clang attribute push(__attribute__((target("popcnt"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_hamming_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_hamming_b8_haswell(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // x86 supports unaligned loads and works just fine with the scalar version for small vectors. 
simsimd_size_t differences = 0; for (; n_words >= 8; n_words -= 8, a += 8, b += 8) - differences += _mm_popcnt_u64(*(simsimd_u64_t const*)a ^ *(simsimd_u64_t const*)b); - for (; n_words; --n_words, ++a, ++b) - differences += _mm_popcnt_u32(*a ^ *b); + differences += _mm_popcnt_u64(*(simsimd_u64_t const *)a ^ *(simsimd_u64_t const *)b); + for (; n_words; --n_words, ++a, ++b) differences += _mm_popcnt_u32(*a ^ *b); *result = differences; } -SIMSIMD_PUBLIC void simsimd_jaccard_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8_haswell(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // x86 supports unaligned loads and works just fine with the scalar version for small vectors. simsimd_size_t intersection = 0, union_ = 0; for (; n_words >= 8; n_words -= 8, a += 8, b += 8) - intersection += _mm_popcnt_u64(*(simsimd_u64_t const*)a & *(simsimd_u64_t const*)b), - union_ += _mm_popcnt_u64(*(simsimd_u64_t const*)a | *(simsimd_u64_t const*)b); - for (; n_words; --n_words, ++a, ++b) - intersection += _mm_popcnt_u32(*a & *b), union_ += _mm_popcnt_u32(*a | *b); + intersection += _mm_popcnt_u64(*(simsimd_u64_t const *)a & *(simsimd_u64_t const *)b), + union_ += _mm_popcnt_u64(*(simsimd_u64_t const *)a | *(simsimd_u64_t const *)b); + for (; n_words; --n_words, ++a, ++b) intersection += _mm_popcnt_u32(*a & *b), union_ += _mm_popcnt_u32(*a | *b); *result = (union_ != 0) ? 1 - (simsimd_f64_t)intersection / (simsimd_f64_t)union_ : 1; } #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_HASWELL -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 #ifdef __cplusplus } diff --git a/include/simsimd/curved.h b/include/simsimd/curved.h index 5d9846e9..59a99fe6 100644 --- a/include/simsimd/curved.h +++ b/include/simsimd/curved.h @@ -92,40 +92,40 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const* a, simsim SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, simsimd_size_t n, simsimd_distance_t* result); // clang-format on -#define SIMSIMD_MAKE_BILINEAR(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_bilinear_##input_type##_##name( \ - simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_##input_type##_t const* c, \ - simsimd_size_t n, simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t sum = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t partial = 0; \ - simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \ - for (simsimd_size_t j = 0; j != n; ++j) { \ - simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \ - simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ - partial += c_ij * b_j; \ - } \ - sum += a_i * partial; \ - } \ - *result = (simsimd_distance_t)sum; \ +#define SIMSIMD_MAKE_BILINEAR(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_bilinear_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \ + simsimd_size_t n, simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t sum = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t partial = 0; \ + simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \ + for 
(simsimd_size_t j = 0; j != n; ++j) { \ + simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \ + simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ + partial += c_ij * b_j; \ + } \ + sum += a_i * partial; \ + } \ + *result = (simsimd_distance_t)sum; \ } -#define SIMSIMD_MAKE_MAHALANOBIS(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_mahalanobis_##input_type##_##name( \ - simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_##input_type##_t const* c, \ - simsimd_size_t n, simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t sum = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t partial = 0; \ - simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \ - for (simsimd_size_t j = 0; j != n; ++j) { \ - simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \ - simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ - partial += c_ij * diff_j; \ - } \ - sum += diff_i * partial; \ - } \ - *result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \ +#define SIMSIMD_MAKE_MAHALANOBIS(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_mahalanobis_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \ + simsimd_size_t n, simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t sum = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t partial = 0; \ + simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \ + for (simsimd_size_t j = 0; j != n; ++j) { \ + simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \ + simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \ + partial += c_ij * diff_j; \ + } \ + sum += diff_i * partial; \ + } \ + *result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \ } SIMSIMD_MAKE_BILINEAR(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f64_serial @@ -149,14 +149,14 @@ SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_maha SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_bilinear_bf16_accurate SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_mahalanobis_bf16_accurate -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_f32_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t a_vec = vdupq_n_f32(a[i]); @@ -177,8 +177,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const* a, simsimd_f3 for (simsimd_size_t i = 0; i != n; ++i) { simsimd_f32_t a_i = a[i]; simsimd_f32_t partial_sum = 0; - for (simsimd_size_t j = tail_start; j != n; ++j) - partial_sum += b[j] * c[i * n + j]; + for (simsimd_size_t j = tail_start; j != n; ++j) partial_sum += b[j] * c[i * n + j]; sum += a[i] * 
partial_sum; } } @@ -186,8 +185,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_neon(simsimd_f32_t const* a, simsimd_f3 *result = sum; } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_f32_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t diff_i_vec = vdupq_n_f32(a[i] - b[i]); @@ -229,16 +228,16 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_neon(simsimd_f32_t const* a, simsimd #pragma GCC target("arch=armv8.2-a+simd+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { // MSVC doesn't recognize `vdup_n_f16` as a valid intrinsic - float32x4_t a_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const*)(a + i)))); + float32x4_t a_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(a + i)))); float32x4_t partial_sum_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { - float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)(b + j))); - float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)(c + i * n + j))); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(b + j))); + float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(c + i * n + j))); partial_sum_vec = vmlaq_f32(partial_sum_vec, b_vec, c_vec); } sum_vec = vmlaq_f32(sum_vec, a_vec, partial_sum_vec); @@ -261,20 +260,20 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_neon(simsimd_f16_t const* a, simsimd_f1 *result = sum; } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { // MSVC doesn't recognize `vdup_n_f16` as a valid intrinsic - float32x4_t a_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const*)(a + i)))); - float32x4_t b_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const*)(b + i)))); + float32x4_t a_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(a + i)))); + float32x4_t b_i_vec = vcvt_f32_f16(vreinterpret_f16_s16(vdup_n_s16(*(short const *)(b + i)))); float32x4_t diff_i_vec = vsubq_f32(a_i_vec, b_i_vec); float32x4_t partial_sum_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 4 <= n; j += 4) { - float32x4_t a_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)(a + j))); - float32x4_t b_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)(b + j))); + float32x4_t a_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(a + j))); + float32x4_t 
b_j_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(b + j))); float32x4_t diff_j_vec = vsubq_f32(a_j_vec, b_j_vec); - float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)(c + i * n + j))); + float32x4_t c_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)(c + i * n + j))); partial_sum_vec = vmlaq_f32(partial_sum_vec, diff_j_vec, c_vec); } sum_vec = vmlaq_f32(sum_vec, diff_i_vec, partial_sum_vec); @@ -310,15 +309,15 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const* a, simsimd #pragma GCC target("arch=armv8.6-a+simd+bf16") #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { float32x4_t a_vec = vdupq_n_f32(simsimd_bf16_to_f32(a + i)); float32x4_t partial_sum_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { - bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)(b + j)); - bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)(c + i * n + j)); + bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j)); + bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(c + i * n + j)); partial_sum_vec = vbfdotq_f32(partial_sum_vec, b_vec, c_vec); } sum_vec = vmlaq_f32(sum_vec, a_vec, partial_sum_vec); @@ -341,9 +340,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const* a, simsimd_ *result = sum; } -SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); for (simsimd_size_t i = 0; i != n; ++i) { simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i); @@ -351,8 +350,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const* a, simsi float32x4_t diff_i_vec = vdupq_n_f32(a_i - b_i); float32x4_t partial_sum_vec = vdupq_n_f32(0); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { - bfloat16x8_t a_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)(a + j)); - bfloat16x8_t b_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)(b + j)); + bfloat16x8_t a_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(a + j)); + bfloat16x8_t b_j_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j)); // Arm NEON does not have a native subtraction instruction for `bf16`, // so we need to convert to `f32` first, subtract, and only then get back to `bf16` @@ -365,7 +364,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const* a, simsi float32x4_t diff_j_vec_low = vsubq_f32(a_j_vec_low, b_j_vec_low); bfloat16x8_t diff_j_vec = vcombine_bf16(vcvt_bf16_f32(diff_j_vec_low), vcvt_bf16_f32(diff_j_vec_high)); - bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)(c + i * n + j)); + bfloat16x8_t c_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(c + i * n + j)); partial_sum_vec = 
vbfdotq_f32(partial_sum_vec, diff_j_vec, c_vec); } sum_vec = vmlaq_f32(sum_vec, diff_i_vec, partial_sum_vec); @@ -405,23 +404,23 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const* a, simsi #pragma GCC pop_options #endif // SIMSIMD_TARGET_NEON_BF16 -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_HASWELL #pragma GCC push_options #pragma GCC target("avx2", "f16c", "fma") #pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { __m256 sum_vec = _mm256_setzero_ps(); for (simsimd_size_t i = 0; i != n; ++i) { - __m256 a_vec = _mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(a + i))); + __m256 a_vec = _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))); __m256 partial_sum_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { - __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)(b + j))); - __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)(c + i * n + j))); + __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(b + j))); + __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); partial_sum_vec = _mm256_fmadd_ps(b_vec, c_vec, partial_sum_vec); } sum_vec = _mm256_fmadd_ps(a_vec, partial_sum_vec, sum_vec); @@ -433,7 +432,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const* a, simsimd simsimd_size_t tail_start = n - tail_length; if (tail_length) { for (simsimd_size_t i = 0; i != n; ++i) { - simsimd_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(a + i)))); + simsimd_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i)))); __m256 b_vec = _simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length); __m256 c_vec = _simsimd_partial_load_f16x8_haswell(c + i * n + tail_start, tail_length); simsimd_f32_t partial_sum = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_vec, c_vec)); @@ -444,20 +443,20 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const* a, simsimd *result = sum; } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, - simsimd_f16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, + simsimd_f16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { __m256 sum_vec = _mm256_setzero_ps(); for (simsimd_size_t i = 0; i != n; ++i) { - __m256 diff_i_vec = _mm256_sub_ps( // - _mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(a + i))), // - _mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(b + i)))); + __m256 diff_i_vec = _mm256_sub_ps( // + _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), // + _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i)))); __m256 partial_sum_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { __m256 diff_j_vec = _mm256_sub_ps( // - _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)(a + j))), - _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)(b + j)))); - __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)(c + i 
* n + j))); + _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(a + j))), + _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(b + j)))); + __m256 c_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); partial_sum_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); } sum_vec = _mm256_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); @@ -469,9 +468,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const* a, sims simsimd_size_t tail_start = n - tail_length; if (tail_length) { for (simsimd_size_t i = 0; i != n; ++i) { - simsimd_f32_t diff_i = _mm256_cvtss_f32(_mm256_sub_ps( // - _mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(a + i))), // - _mm256_cvtph_ps(_mm_set1_epi16(*(short const*)(b + i))))); + simsimd_f32_t diff_i = _mm256_cvtss_f32(_mm256_sub_ps( // + _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), // + _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i))))); __m256 diff_j_vec = _mm256_sub_ps( // _simsimd_partial_load_f16x8_haswell(a + tail_start, tail_length), _simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length)); @@ -484,17 +483,17 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const* a, sims *result = _simsimd_sqrt_f32_haswell(sum); } -SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { __m256 sum_vec = _mm256_setzero_ps(); for (simsimd_size_t i = 0; i != n; ++i) { // The `simsimd_bf16_to_f32` is cheaper than `_simsimd_bf16x8_to_f32x8_haswell` __m256 a_vec = _mm256_set1_ps(simsimd_bf16_to_f32(a + i)); __m256 partial_sum_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { - __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)(b + j))); - __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)(c + i * n + j))); + __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j))); + __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); partial_sum_vec = _mm256_fmadd_ps(b_vec, c_vec, partial_sum_vec); } sum_vec = _mm256_fmadd_ps(a_vec, partial_sum_vec, sum_vec); @@ -519,9 +518,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const* a, simsi *result = sum; } -SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { __m256 sum_vec = _mm256_setzero_ps(); for (simsimd_size_t i = 0; i != n; ++i) { __m256 diff_i_vec = _mm256_sub_ps( // @@ -529,10 +528,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const* a, si _mm256_set1_ps(simsimd_bf16_to_f32(b + i))); __m256 partial_sum_vec = _mm256_setzero_ps(); for (simsimd_size_t j = 0; j + 8 <= n; j += 8) { - __m256 diff_j_vec = _mm256_sub_ps( // - _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)(a + j))), // - _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)(b + j)))); - __m256 c_vec = 
_simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)(c + i * n + j))); + __m256 diff_j_vec = _mm256_sub_ps( // + _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(a + j))), // + _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j)))); + __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(c + i * n + j))); partial_sum_vec = _mm256_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); } sum_vec = _mm256_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); @@ -567,8 +566,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const* a, si #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2") #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_f32_t const* c, - simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t *result) { simsimd_size_t tail_length = n % 16; simsimd_size_t tail_start = n - tail_length; __m512 sum_vec = _mm512_setzero_ps(); @@ -584,23 +583,23 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32_skylake(simsimd_f32_t const* a, simsimd if (j + 16 <= n) { b_vec = _mm512_loadu_ps(b + j); c_vec = _mm512_loadu_ps(c + i * n + j); - } else { + } + else { b_vec = _mm512_maskz_loadu_ps(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_ps(tail_mask, c + i * n + tail_start); } partial_sum_vec = _mm512_fmadd_ps(b_vec, c_vec, partial_sum_vec); j += 16; - if (j < n) - goto simsimd_bilinear_f32_skylake_cycle; + if (j < n) goto simsimd_bilinear_f32_skylake_cycle; sum_vec = _mm512_fmadd_ps(a_vec, partial_sum_vec, sum_vec); } *result = _mm512_reduce_add_ps(sum_vec); } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, - simsimd_f32_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, + simsimd_f32_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t tail_length = n % 16; simsimd_size_t tail_start = n - tail_length; __m512 sum_vec = _mm512_setzero_ps(); @@ -618,7 +617,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const* a, sims a_j_vec = _mm512_loadu_ps(a + j); b_j_vec = _mm512_loadu_ps(b + j); c_vec = _mm512_loadu_ps(c + i * n + j); - } else { + } + else { a_j_vec = _mm512_maskz_loadu_ps(tail_mask, a + tail_start); b_j_vec = _mm512_maskz_loadu_ps(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_ps(tail_mask, c + i * n + tail_start); @@ -626,8 +626,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const* a, sims diff_j_vec = _mm512_sub_ps(a_j_vec, b_j_vec); partial_sum_vec = _mm512_fmadd_ps(diff_j_vec, c_vec, partial_sum_vec); j += 16; - if (j < n) - goto simsimd_bilinear_f32_skylake_cycle; + if (j < n) goto simsimd_bilinear_f32_skylake_cycle; sum_vec = _mm512_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); } @@ -641,11 +640,11 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32_skylake(simsimd_f32_t const* a, sims #if SIMSIMD_TARGET_GENOA #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ +#pragma clang attribute 
push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) { simsimd_size_t tail_length = n % 32; simsimd_size_t tail_start = n - tail_length; __m512 sum_vec = _mm512_setzero_ps(); @@ -661,23 +660,23 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const* a, simsimd if (j + 32 <= n) { b_vec = _mm512_loadu_epi16(b + j); c_vec = _mm512_loadu_epi16(c + i * n + j); - } else { + } + else { b_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } partial_sum_vec = _mm512_dpbf16_ps(partial_sum_vec, (__m512bh)(b_vec), (__m512bh)(c_vec)); j += 32; - if (j < n) - goto simsimd_bilinear_bf16_genoa_cycle; + if (j < n) goto simsimd_bilinear_bf16_genoa_cycle; sum_vec = _mm512_fmadd_ps(a_vec, partial_sum_vec, sum_vec); } *result = _mm512_reduce_add_ps(sum_vec); } -SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, - simsimd_bf16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t tail_length = n % 32; simsimd_size_t tail_start = n - tail_length; __m512 sum_vec = _mm512_setzero_ps(); @@ -695,7 +694,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const* a, sims a_j_vec = _mm512_loadu_epi16(a + j); b_j_vec = _mm512_loadu_epi16(b + j); c_vec = _mm512_loadu_epi16(c + i * n + j); - } else { + } + else { a_j_vec = _mm512_maskz_loadu_epi16(tail_mask, a + tail_start); b_j_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); @@ -703,8 +703,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const* a, sims diff_j_vec = _simsimd_substract_bf16x32_genoa(a_j_vec, b_j_vec); partial_sum_vec = _mm512_dpbf16_ps(partial_sum_vec, (__m512bh)(diff_j_vec), (__m512bh)(c_vec)); j += 32; - if (j < n) - goto simsimd_mahalanobis_bf16_genoa_cycle; + if (j < n) goto simsimd_mahalanobis_bf16_genoa_cycle; sum_vec = _mm512_fmadd_ps(diff_i_vec, partial_sum_vec, sum_vec); } @@ -718,19 +717,19 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const* a, sims #if SIMSIMD_TARGET_SAPPHIRE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512fp16") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, - simsimd_f16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, + simsimd_f16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t tail_length = n % 32; simsimd_size_t tail_start = n - tail_length; __m512h sum_vec = _mm512_setzero_ph(); 
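// For readers skimming the intrinsics: every kernel in this section evaluates either
// the bilinear form a^T * C * b or the Mahalanobis distance sqrt((a-b)^T * C * (a-b)),
// with C stored row-major as an n-by-n matrix. A scalar reference of both, handy for
// sanity-checking the vectorized paths; the `*_reference` names below are illustrative,
// not library API.
#include <math.h>   // sqrt
#include <stddef.h> // size_t

static double bilinear_form_reference(float const *a, float const *b, float const *c, size_t n) {
    double sum = 0;
    for (size_t i = 0; i != n; ++i) {
        double partial = 0; // dot product of row `i` of C with `b`
        for (size_t j = 0; j != n; ++j) partial += (double)c[i * n + j] * b[j];
        sum += (double)a[i] * partial;
    }
    return sum;
}

static double mahalanobis_reference(float const *a, float const *b, float const *c, size_t n) {
    double sum = 0;
    for (size_t i = 0; i != n; ++i) {
        double partial = 0; // dot product of row `i` of C with the difference vector
        for (size_t j = 0; j != n; ++j) partial += (double)c[i * n + j] * ((double)a[j] - b[j]);
        sum += ((double)a[i] - b[i]) * partial;
    }
    return sqrt(sum); // assumes C is an inverse covariance matrix, so `sum` is non-negative
}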
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length); for (simsimd_size_t i = 0; i != n; ++i) { - __m512h a_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const*)(a + i))); + __m512h a_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(a + i))); __m512h partial_sum_vec = _mm512_setzero_ph(); __m512i b_vec, c_vec; simsimd_size_t j = 0; @@ -739,31 +738,31 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const* a, simsim if (j + 32 <= n) { b_vec = _mm512_loadu_epi16(b + j); c_vec = _mm512_loadu_epi16(c + i * n + j); - } else { + } + else { b_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); } partial_sum_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(b_vec), _mm512_castsi512_ph(c_vec), partial_sum_vec); j += 32; - if (j < n) - goto simsimd_bilinear_f16_sapphire_cycle; + if (j < n) goto simsimd_bilinear_f16_sapphire_cycle; sum_vec = _mm512_fmadd_ph(a_vec, partial_sum_vec, sum_vec); } *result = _mm512_reduce_add_ph(sum_vec); } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, - simsimd_f16_t const* c, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, + simsimd_f16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t tail_length = n % 32; simsimd_size_t tail_start = n - tail_length; __m512h sum_vec = _mm512_setzero_ph(); __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length); for (simsimd_size_t i = 0; i != n; ++i) { - __m512h a_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const*)(a + i))); - __m512h b_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const*)(b + i))); + __m512h a_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(a + i))); + __m512h b_i_vec = _mm512_castsi512_ph(_mm512_set1_epi16(*(short const *)(b + i))); __m512h diff_i_vec = _mm512_sub_ph(a_i_vec, b_i_vec); __m512h partial_sum_vec = _mm512_setzero_ph(); __m512h diff_j_vec; @@ -776,7 +775,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim a_j_vec = _mm512_loadu_epi16(a + j); b_j_vec = _mm512_loadu_epi16(b + j); c_vec = _mm512_loadu_epi16(c + i * n + j); - } else { + } + else { a_j_vec = _mm512_maskz_loadu_epi16(tail_mask, a + tail_start); b_j_vec = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start); c_vec = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start); @@ -784,8 +784,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim diff_j_vec = _mm512_sub_ph(_mm512_castsi512_ph(a_j_vec), _mm512_castsi512_ph(b_j_vec)); partial_sum_vec = _mm512_fmadd_ph(diff_j_vec, _mm512_castsi512_ph(c_vec), partial_sum_vec); j += 32; - if (j < n) - goto simsimd_mahalanobis_f16_sapphire_cycle; + if (j < n) goto simsimd_mahalanobis_f16_sapphire_cycle; sum_vec = _mm512_fmadd_ph(diff_i_vec, partial_sum_vec, sum_vec); } @@ -795,7 +794,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SAPPHIRE -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 #ifdef __cplusplus } diff --git a/include/simsimd/dot.h b/include/simsimd/dot.h index 112a659f..556940b1 100644 --- a/include/simsimd/dot.h +++ b/include/simsimd/dot.h @@ -153,51 +153,51 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f SIMSIMD_PUBLIC void 
simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); // clang-format on -#define SIMSIMD_MAKE_DOT(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_dot_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t ab = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ - ab += ai * bi; \ - } \ - *result = ab; \ +#define SIMSIMD_MAKE_DOT(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_dot_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t ab = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + ab += ai * bi; \ + } \ + *result = ab; \ } -#define SIMSIMD_MAKE_COMPLEX_DOT(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_dot_##input_type##c_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* results) { \ - simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ - for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ - simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ - ab_real += ar * br - ai * bi; \ - ab_imag += ar * bi + ai * br; \ - } \ - results[0] = ab_real; \ - results[1] = ab_imag; \ +#define SIMSIMD_MAKE_COMPLEX_DOT(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_dot_##input_type##c_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ + simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ + ab_real += ar * br - ai * bi; \ + ab_imag += ar * bi + ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ } -#define SIMSIMD_MAKE_COMPLEX_VDOT(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_vdot_##input_type##c_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* results) { \ - simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ - for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ - simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ - ab_real += ar * br + ai * bi; \ - ab_imag += ar * bi - ai * br; \ - } \ - results[0] = ab_real; \ - results[1] = ab_imag; \ 
+#define SIMSIMD_MAKE_COMPLEX_VDOT(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_vdot_##input_type##c_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ + simsimd_##accumulator_type##_t ar = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t br = load_and_convert(b + i); \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i + 1); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i + 1); \ + ab_real += ar * br + ai * bi; \ + ab_imag += ar * bi - ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ } SIMSIMD_MAKE_DOT(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f64_serial @@ -231,27 +231,25 @@ SIMSIMD_MAKE_DOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_d SIMSIMD_MAKE_COMPLEX_DOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_dot_bf16c_accurate SIMSIMD_MAKE_COMPLEX_VDOT(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_accurate -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SIMSIMD_INTERNAL float32x4_t _simsimd_partial_load_f32x4_neon(simsimd_f32_t const* a, simsimd_size_t n) { +SIMSIMD_INTERNAL float32x4_t _simsimd_partial_load_f32x4_neon(simsimd_f32_t const *a, simsimd_size_t n) { union { float32x4_t vec; simsimd_f32_t scalars[4]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) - result.scalars[i] = a[i]; - for (; i < 4; ++i) - result.scalars[i] = 0; + for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < 4; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0); simsimd_size_t i = 0; for (; i + 4 <= n; i += 4) { @@ -260,13 +258,12 @@ SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const* a, simsimd_f32_t c ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); } simsimd_f32_t ab = vaddvq_f32(ab_vec); - for (; i < n; ++i) - ab += a[i] * b[i]; + for (; i < n; ++i) ab += a[i] * b[i]; *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, // - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t *results) { float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); simsimd_size_t i = 0; @@ -300,8 +297,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t results[1] = ab_imag; } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, // - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t *results) { float32x4_t ab_real_vec = vdupq_n_f32(0); float32x4_t ab_imag_vec = vdupq_n_f32(0); simsimd_size_t i = 0; @@ -337,13 +334,15 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t 
const* a, simsimd_f32_t #pragma clang attribute pop #pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON +#if SIMSIMD_TARGET_NEON_I8 #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+dotprod") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { int32x4_t ab_vec = vdupq_n_s32(0); simsimd_size_t i = 0; @@ -372,8 +371,8 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t cons *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { uint32x4_t ab_vec = vdupq_n_u32(0); simsimd_size_t i = 0; @@ -395,14 +394,14 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const* a, simsimd_u8_t cons #pragma clang attribute pop #pragma GCC pop_options -#endif +#endif // SIMSIMD_TARGET_NEON_I8 #if SIMSIMD_TARGET_NEON_F16 #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) -SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t const* a, simsimd_size_t n) { +SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t const *a, simsimd_size_t n) { // In case the software emulation for `f16` scalars is enabled, the `simsimd_f16_to_f32` // function will run. It is extremely slow, so even for the tail, let's combine serial // loads and stores with vectorized math. 
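// The comment above describes the tail-handling strategy shared by all the partial
// loads in this file: copy the last `n < lane_count` scalars into a zeroed buffer,
// then reuse the full-width vector path. A portable sketch of the same idea, with
// hypothetical names (`load_f16x4_zero_padded` is not part of the library):
#include <stddef.h> // size_t
#include <string.h> // memset, memcpy

typedef unsigned short f16_bits_t; // raw 16-bit storage for an IEEE half-precision value

static void load_f16x4_zero_padded(f16_bits_t const *a, size_t n, f16_bits_t padded[4]) {
    memset(padded, 0, 4 * sizeof(f16_bits_t)); // zero the lanes past the tail
    memcpy(padded, a, n * sizeof(f16_bits_t)); // copy only the `n < 4` valid scalars
    // `padded` can now go through the same 4-lane convert-and-FMA path as a full load.
}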
@@ -411,15 +410,13 @@ SIMSIMD_INTERNAL float16x4_t _simsimd_partial_load_f16x4_neon(simsimd_f16_t cons simsimd_f16_t scalars[4]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) - result.scalars[i] = a[i]; - for (; i < 4; ++i) - result.scalars[i] = 0; + for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < 4; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0); float32x4_t a_vec, b_vec; simsimd_size_t i = 0; @@ -429,19 +426,19 @@ SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const* a, simsimd_f16_t c a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); n = 0; - } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b)); + } + else { + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); a += 4, b += 4, n -= 4; } ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); - if (n) - goto simsimd_dot_f16_neon_cycle; + if (n) goto simsimd_dot_f16_neon_cycle; *result = vaddvq_f32(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, // - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t *results) { // A nicer approach is to use `f16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -453,8 +450,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short*)a); - int16x4x2_t b_vec = vld2_s16((short*)b); + int16x4x2_t a_vec = vld2_s16((short *)a); + int16x4x2_t b_vec = vld2_s16((short *)b); float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); @@ -475,8 +472,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t results[1] += vaddvq_f32(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, // - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t *results) { // A nicer approach is to use `f16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -488,8 +485,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t // Unpack the input arrays into real and imaginary parts. 
// MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short*)a); - int16x4x2_t b_vec = vld2_s16((short*)b); + int16x4x2_t a_vec = vld2_s16((short *)a); + int16x4x2_t b_vec = vld2_s16((short *)b); float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); @@ -519,21 +516,19 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t #pragma GCC target("arch=armv8.6-a+simd+bf16") #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) -SIMSIMD_INTERNAL bfloat16x8_t _simsimd_partial_load_bf16x8_neon(simsimd_bf16_t const* a, simsimd_size_t n) { +SIMSIMD_INTERNAL bfloat16x8_t _simsimd_partial_load_bf16x8_neon(simsimd_bf16_t const *a, simsimd_size_t n) { union { bfloat16x8_t vec; simsimd_bf16_t scalars[8]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) - result.scalars[i] = a[i]; - for (; i < 8; ++i) - result.scalars[i] = 0; + for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < 8; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0); bfloat16x8_t a_vec, b_vec; @@ -542,20 +537,20 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_ a_vec = _simsimd_partial_load_bf16x8_neon(a, n); b_vec = _simsimd_partial_load_bf16x8_neon(b, n); n = 0; - } else { - a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a); - b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b); + } + else { + a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)a); + b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)b); a += 8, b += 8, n -= 8; } ab_vec = vbfdotq_f32(ab_vec, a_vec, b_vec); - if (n) - goto simsimd_dot_bf16_neon_cycle; + if (n) goto simsimd_dot_bf16_neon_cycle; *result = vaddvq_f32(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -567,8 +562,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16 // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. 
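// Once the real and imaginary parts are deinterleaved, the complex kernels apply the
// textbook product. A scalar reference matching the `SIMSIMD_MAKE_COMPLEX_DOT` macro
// earlier in this header, assuming interleaved storage [re0, im0, re1, im1, ...];
// the function name is made up for illustration:
#include <stddef.h> // size_t

static void complex_dot_f32_reference(float const *a, float const *b, size_t n, double results[2]) {
    double ab_real = 0, ab_imag = 0;
    for (size_t i = 0; i + 2 <= n; i += 2) {
        double ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1];
        ab_real += ar * br - ai * bi; // Re((ar + i*ai) * (br + i*bi))
        ab_imag += ar * bi + ai * br; // Im((ar + i*ai) * (br + i*bi))
    }
    results[0] = ab_real, results[1] = ab_imag;
    // The `vdot` variants conjugate `a` first, flipping the two signs:
    // ab_real += ar * br + ai * bi; ab_imag += ar * bi - ai * br;
}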
- int16x4x2_t a_vec = vld2_s16((short*)a); - int16x4x2_t b_vec = vld2_s16((short*)b); + int16x4x2_t a_vec = vld2_s16((short *)a); + int16x4x2_t b_vec = vld2_s16((short *)b); float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0])); @@ -589,8 +584,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16 results[1] += vaddvq_f32(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, // - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t *results) { // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries @@ -602,8 +597,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf1 // Unpack the input arrays into real and imaginary parts. // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. - int16x4x2_t a_vec = vld2_s16((short*)a); - int16x4x2_t b_vec = vld2_s16((short*)b); + int16x4x2_t a_vec = vld2_s16((short *)a); + int16x4x2_t b_vec = vld2_s16((short *)b); float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0])); float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1])); float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0])); @@ -634,8 +629,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf1 #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat32_t ab_vec = svdup_f32(0.f); do { @@ -648,8 +643,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const* a, simsimd_f32_t co *result = svaddv_f32(svptrue_b32(), ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat32_t ab_real_vec = svdup_f32(0.f); svfloat32_t ab_imag_vec = svdup_f32(0.f); @@ -671,8 +666,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t c results[1] = svaddv_f32(svptrue_b32(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat32_t ab_real_vec = svdup_f32(0.f); svfloat32_t ab_imag_vec = svdup_f32(0.f); @@ -694,8 +689,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t results[1] = svaddv_f32(svptrue_b32(), 
ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat64_t ab_vec = svdup_f64(0.); do { @@ -708,8 +703,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const* a, simsimd_f64_t co *result = svaddv_f64(svptrue_b32(), ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat64_t ab_real_vec = svdup_f64(0.); svfloat64_t ab_imag_vec = svdup_f64(0.); @@ -731,8 +726,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t c results[1] = svaddv_f64(svptrue_b64(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat64_t ab_real_vec = svdup_f64(0.); svfloat64_t ab_imag_vec = svdup_f64(0.); @@ -761,12 +756,12 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t #pragma GCC target("arch=armv8.2-a+sve+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const *a_enum, simsimd_f16_t const *b_enum, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat16_t ab_vec = svdup_f16(0); - simsimd_f16_for_arm_simd_t const* a = (simsimd_f16_for_arm_simd_t const*)(a_enum); - simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + simsimd_f16_for_arm_simd_t const *a = (simsimd_f16_for_arm_simd_t const *)(a_enum); + simsimd_f16_for_arm_simd_t const *b = (simsimd_f16_for_arm_simd_t const *)(b_enum); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); svfloat16_t a_vec = svld1_f16(pg_vec, a + i); @@ -778,15 +773,15 @@ SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16 *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat16_t ab_real_vec = svdup_f16(0); svfloat16_t ab_imag_vec = svdup_f16(0); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i / 2, (unsigned int)n / 2); - svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)a + i); - svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)b + i); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)a + i); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)b + i); svfloat16_t a_real_vec = svget2_f16(a_vec, 0); svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); svfloat16_t b_real_vec = 
svget2_f16(b_vec, 0); @@ -801,15 +796,15 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t c results[1] = svaddv_f16(svptrue_b16(), ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { simsimd_size_t i = 0; svfloat16_t ab_real_vec = svdup_f16(0); svfloat16_t ab_imag_vec = svdup_f16(0); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i / 2, (unsigned int)n / 2); - svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)a + i); - svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)b + i); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)a + i); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const *)b + i); svfloat16_t a_real_vec = svget2_f16(a_vec, 0); svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); svfloat16_t b_real_vec = svget2_f16(b_vec, 0); @@ -827,14 +822,28 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SVE -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_HASWELL #pragma GCC push_options #pragma GCC target("avx2", "f16c", "fma") #pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) +SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f64x4_haswell(__m256d vec) { + // Reduce the double-precision vector to a scalar + // Horizontal add the first and second double-precision values, and third and fourth + __m128d vec_low = _mm256_castpd256_pd128(vec); + __m128d vec_high = _mm256_extractf128_pd(vec, 1); + __m128d vec128 = _mm_add_pd(vec_low, vec_high); + + // Horizontal add again to accumulate all four values into one + vec128 = _mm_hadd_pd(vec128, vec128); + + // Convert the final sum to a scalar double-precision value and return + return _mm_cvtsd_f64(vec128); +} + SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f32x8_haswell(__m256 vec) { // Convert the lower and higher 128-bit lanes of the input vector to double precision __m128 low_f32 = _mm256_castps256_ps128(vec); @@ -846,18 +855,7 @@ SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f32x8_haswell(__m256 vec) { // Perform the addition in double-precision __m256d sum = _mm256_add_pd(low_f64, high_f64); - - // Reduce the double-precision vector to a scalar - // Horizontal add the first and second double-precision values, and third and fourth - __m128d sum_low = _mm256_castpd256_pd128(sum); - __m128d sum_high = _mm256_extractf128_pd(sum, 1); - __m128d sum128 = _mm_add_pd(sum_low, sum_high); - - // Horizontal add again to accumulate all four values into one - sum128 = _mm_hadd_pd(sum128, sum128); - - // Convert the final sum to a scalar double-precision value and return - return _mm_cvtsd_f64(sum128); + return _simsimd_reduce_f64x4_haswell(sum); } SIMSIMD_INTERNAL simsimd_i32_t _simsimd_reduce_i32x8_haswell(__m256i vec) { @@ -869,8 +867,8 @@ SIMSIMD_INTERNAL simsimd_i32_t _simsimd_reduce_i32x8_haswell(__m256i vec) { return _mm_cvtsi128_si32(sum); } -SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t 
const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m256 ab_vec = _mm256_setzero_ps(); simsimd_size_t i = 0; @@ -880,13 +878,12 @@ SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const* a, simsimd_f32_ ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); } simsimd_f64_t ab = _simsimd_reduce_f32x8_haswell(ab_vec); - for (; i < n; ++i) - ab += a[i] * b[i]; + for (; i < n; ++i) ab += a[i] * b[i]; *results = ab; } -SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { // The naive approach would be to use FMA and FMS instructions on different parts of the vectors. // Prior to that we would need to shuffle the input vectors to separate real and imaginary parts. @@ -955,8 +952,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32 results[1] = ab_imag; } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m256 ab_real_vec = _mm256_setzero_ps(); __m256 ab_imag_vec = _mm256_setzero_ps(); @@ -997,7 +994,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const* a, simsimd_f3 results[1] = ab_imag; } -SIMSIMD_INTERNAL __m256 _simsimd_partial_load_f16x8_haswell(simsimd_f16_t const* a, simsimd_size_t n) { +SIMSIMD_INTERNAL __m256 _simsimd_partial_load_f16x8_haswell(simsimd_f16_t const *a, simsimd_size_t n) { // In case the software emulation for `f16` scalars is enabled, the `simsimd_f16_to_f32` // function will run. It is extremely slow, so even for the tail, let's combine serial // loads and stores with vectorized math. 
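// A numerics note on `_simsimd_reduce_f32x8_haswell` above: the eight f32 lanes are
// widened to f64 before any horizontal addition, so the reduction step itself adds no
// single-precision rounding. A scalar analogue of that widening reduction, for
// illustration only:
static double reduce_f32x8_reference(float const lanes[8]) {
    double low = 0, high = 0;
    for (int i = 0; i != 4; ++i) low += lanes[i];  // lower 128-bit lane, summed in f64
    for (int i = 4; i != 8; ++i) high += lanes[i]; // upper 128-bit lane, summed in f64
    return low + high;                             // final horizontal add, still in f64
}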
@@ -1006,15 +1003,13 @@ SIMSIMD_INTERNAL __m256 _simsimd_partial_load_f16x8_haswell(simsimd_f16_t const* simsimd_f16_t scalars[8]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) - result.scalars[i] = a[i]; - for (; i < 8; ++i) - result.scalars[i] = 0; + for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < 8; ++i) result.scalars[i] = 0; return _mm256_cvtph_ps(result.vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(); @@ -1023,9 +1018,10 @@ SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const* a, simsimd_f16_ a_vec = _simsimd_partial_load_f16x8_haswell(a, n); b_vec = _simsimd_partial_load_f16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } // We can silence the NaNs using blends: // __m256 a_is_nan = _mm256_cmp_ps(a_vec, a_vec, _CMP_UNORD_Q); // __m256 b_is_nan = _mm256_cmp_ps(b_vec, b_vec, _CMP_UNORD_Q); // ab_vec = _mm256_blendv_ps(_mm256_fmadd_ps(a_vec, b_vec, ab_vec), ab_vec, _mm256_or_ps(a_is_nan, b_is_nan)); // ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); - if (n) - goto simsimd_dot_f16_haswell_cycle; + if (n) goto simsimd_dot_f16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { // Ideally the implementation would load 256 bits' worth of vector data at a time, // shuffle those within a register, split in halves, and only then upcast. // That way, we are stepping through 32x 16-bit vector components at a time, or 16 dimensions.
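// The file repeatedly warns that software emulation of `f16` scalars (the
// `simsimd_f16_to_f32` fallback) is extremely slow; `_mm256_cvtph_ps` above performs
// the same conversion in hardware, eight lanes at a time. One common portable
// formulation of that conversion, shown for reference and not necessarily the
// library's exact fallback:
#include <stdint.h>
#include <string.h>

static float f16_bits_to_f32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exponent = (h >> 10) & 0x1Fu;
    uint32_t mantissa = h & 0x3FFu;
    uint32_t bits;
    if (exponent == 0x1Fu) bits = sign | 0x7F800000u | (mantissa << 13); // Inf or NaN
    else if (exponent != 0) bits = sign | ((exponent + 112u) << 23) | (mantissa << 13); // normal
    else if (mantissa != 0) { // subnormal: renormalize the mantissa
        exponent = 113u;
        while ((mantissa & 0x400u) == 0) mantissa <<= 1, --exponent;
        bits = sign | (exponent << 23) | ((mantissa & 0x3FFu) << 13);
    }
    else bits = sign; // signed zero
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}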
@@ -1065,8 +1060,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16 ); while (n >= 8) { - __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); @@ -1083,8 +1078,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16 results[1] += _simsimd_reduce_f32x8_haswell(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m256 ab_real_vec = _mm256_setzero_ps(); __m256 ab_imag_vec = _mm256_setzero_ps(); @@ -1101,8 +1096,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f1 ); while (n >= 8) { - __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + __m256 a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + __m256 b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); @@ -1119,8 +1114,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f1 results[1] += _simsimd_reduce_f32x8_haswell(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1139,8 +1134,8 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c // This can easily lead to noticeable numerical errors in the final result. 
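// The warning above is worth a number: instructions that pairwise-add 8-bit products
// into saturating 16-bit lanes can overflow, since (-128)*(-128) + (-128)*(-128) = 32768
// is already one past INT16_MAX, and the mixed unsigned-by-signed variants can reach
// 255*127 + 255*127 = 64770. Widening before multiplying avoids this, which is what
// the kernel above does with its upcasts. A scalar reference (hypothetical name):
#include <stddef.h> // size_t

static int dot_i8_reference(signed char const *a, signed char const *b, size_t n) {
    int ab = 0;
    for (size_t i = 0; i != n; ++i) ab += (int)a[i] * (int)b[i]; // widen before multiplying
    return ab;
}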
simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Upcast `int8` to `int16` __m256i a_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_i8_vec, 0)); @@ -1157,13 +1152,12 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); // Take care of the tail: - for (; i < n; ++i) - ab += (int)(a[i]) * b[i]; + for (; i < n; ++i) ab += (int)(a[i]) * b[i]; *result = ab; } -SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1173,8 +1167,8 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t c // but it has a weird instruction for mixed signed-unsigned 8-bit dot-product. simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking // instructions instead of extracts, as they are much faster and more efficient. @@ -1192,8 +1186,7 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t c int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); // Take care of the tail: - for (; i < n; ++i) - ab += (int)(a[i]) * b[i]; + for (; i < n; ++i) ab += (int)(a[i]) * b[i]; *result = ab; } @@ -1202,7 +1195,23 @@ SIMSIMD_INTERNAL __m256 _simsimd_bf16x8_to_f32x8_haswell(__m128i a) { return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16)); } -SIMSIMD_INTERNAL __m128i _simsimd_partial_load_bf16x8_haswell(simsimd_bf16_t const* a, simsimd_size_t n) { +SIMSIMD_INTERNAL __m128i _simsimd_f32x8_to_bf16x8_haswell(__m256 a) { + // Pack the 32-bit integers into 16-bit integers. + // This is less trivial than unpacking: https://stackoverflow.com/a/77781241/2766161 + // The best approach is to shuffle within lanes first: https://stackoverflow.com/a/49723746/2766161 + // Our shuffling mask will drop the low 2-bytes from every 4-byte word. + __m256i trunc_elements = _mm256_shuffle_epi8( // + _mm256_castps_si256(a), // + _mm256_set_epi8( // + -1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2, // + -1, -1, -1, -1, -1, -1, -1, -1, 15, 14, 11, 10, 7, 6, 3, 2 // + )); + __m256i ordered = _mm256_permute4x64_epi64(trunc_elements, 0x58); + __m128i result = _mm256_castsi256_si128(ordered); + return result; +} + +SIMSIMD_INTERNAL __m128i _simsimd_partial_load_bf16x8_haswell(simsimd_bf16_t const *a, simsimd_size_t n) { // In case the software emulation for `bf16` scalars is enabled, the `simsimd_bf16_to_f32` // function will run. It is extremely slow, so even for the tail, let's combine serial // loads and stores with vectorized math. 
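// The `bf16` helpers in this hunk are simple enough to state in scalar form: the
// upcast places the stored 16 bits into the top half of an f32, the Haswell downcast
// above truncates, and the Skylake variant further below rounds by adding 2^15 before
// shifting. Scalar mirrors, ignoring NaN edge cases just like the vector code
// (helper names are illustrative):
#include <stdint.h>
#include <string.h>

static float bf16_bits_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16; // bf16 is the top half of an IEEE f32
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static uint16_t f32_to_bf16_truncate(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return (uint16_t)(bits >> 16); // drop the low 16 mantissa bits
}

static uint16_t f32_to_bf16_round(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return (uint16_t)((bits + 0x8000u) >> 16); // round to nearest, ties rounding up
}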
@@ -1211,15 +1220,13 @@ SIMSIMD_INTERNAL __m128i _simsimd_partial_load_bf16x8_haswell(simsimd_bf16_t con simsimd_bf16_t scalars[8]; } result; simsimd_size_t i = 0; - for (; i < n; ++i) - result.scalars[i] = a[i]; - for (; i < 8; ++i) - result.scalars[i] = 0; + for (; i < n; ++i) result.scalars[i] = a[i]; + for (; i < 8; ++i) result.scalars[i] = 0; return result.vec; } -SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m128i a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(); @@ -1228,14 +1235,14 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf a_vec = _simsimd_partial_load_bf16x8_haswell(a, n); b_vec = _simsimd_partial_load_bf16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm_lddqu_si128((__m128i const*)a); - b_vec = _mm_lddqu_si128((__m128i const*)b); + } + else { + a_vec = _mm_lddqu_si128((__m128i const *)a); + b_vec = _mm_lddqu_si128((__m128i const *)b); a += 8, b += 8, n -= 8; } ab_vec = _mm256_fmadd_ps(_simsimd_bf16x8_to_f32x8_haswell(a_vec), _simsimd_bf16x8_to_f32x8_haswell(b_vec), ab_vec); - if (n) - goto simsimd_dot_bf16_haswell_cycle; + if (n) goto simsimd_dot_bf16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(ab_vec); } @@ -1256,8 +1263,19 @@ SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f32x16_skylake(__m512 a) { return _mm_cvtss_f32(_mm_hadd_ps(r, r)); } -SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_INTERNAL __m512 _simsimd_bf16x16_to_f32x16_skylake(__m256i a) { + // Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like: + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +SIMSIMD_INTERNAL __m256i _simsimd_f32x16_to_bf16x16_skylake(__m512 a) { + // Add 2^15 and right shift 16 to do round-nearest + __m512i x = _mm512_srli_epi32(_mm512_add_epi32(_mm512_castps_si512(a), _mm512_set1_epi32(1 << 15)), 16); + return _mm512_cvtepi32_epi16(x); +} + +SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 ab_vec = _mm512_setzero(); __m512 a_vec, b_vec; @@ -1267,20 +1285,20 @@ SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_ a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; } ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); - if (n) - goto simsimd_dot_f32_skylake_cycle; + if (n) goto simsimd_dot_f32_skylake_cycle; *result = _simsimd_reduce_f32x16_skylake(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512d ab_vec = _mm512_setzero_pd(); __m512d a_vec, b_vec; @@ -1290,20 +1308,20 @@ SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const* a, simsimd_f64_ a_vec = _mm512_maskz_loadu_pd(mask, a); b_vec = _mm512_maskz_loadu_pd(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_pd(a); b_vec = 
_mm512_loadu_pd(b); a += 8, b += 8, n -= 8; } ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec); - if (n) - goto simsimd_dot_f64_skylake_cycle; + if (n) goto simsimd_dot_f64_skylake_cycle; *result = _mm512_reduce_add_pd(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512 ab_real_vec = _mm512_setzero(); __m512 ab_imag_vec = _mm512_setzero(); __m512 a_vec; @@ -1327,7 +1345,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32 a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; @@ -1335,8 +1354,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32 ab_real_vec = _mm512_fmadd_ps(b_vec, a_vec, ab_real_vec); ab_imag_vec = _mm512_fmadd_ps( _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b_vec), swap_adjacent_vec)), a_vec, ab_imag_vec); - if (n) - goto simsimd_dot_f32c_skylake_cycle; + if (n) goto simsimd_dot_f32c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_real_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(ab_real_vec), sign_flip_vec)); @@ -1346,8 +1364,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32 results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512 ab_real_vec = _mm512_setzero(); __m512 ab_imag_vec = _mm512_setzero(); __m512 a_vec; @@ -1371,7 +1389,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f3 a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; @@ -1379,8 +1398,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f3 ab_real_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_real_vec); b_vec = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_imag_vec); - if (n) - goto simsimd_vdot_f32c_skylake_cycle; + if (n) goto simsimd_vdot_f32c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_imag_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(ab_imag_vec), sign_flip_vec)); @@ -1390,8 +1408,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f3 results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512d ab_real_vec = _mm512_setzero_pd(); __m512d ab_imag_vec = _mm512_setzero_pd(); __m512d a_vec; @@ -1418,7 +1436,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64 
a_vec = _mm512_maskz_loadu_pd(mask, a); b_vec = _mm512_maskz_loadu_pd(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_pd(a); b_vec = _mm512_loadu_pd(b); a += 8, b += 8, n -= 8; @@ -1426,8 +1445,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64 ab_real_vec = _mm512_fmadd_pd(b_vec, a_vec, ab_real_vec); ab_imag_vec = _mm512_fmadd_pd( _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(b_vec), swap_adjacent_vec)), a_vec, ab_imag_vec); - if (n) - goto simsimd_dot_f64c_skylake_cycle; + if (n) goto simsimd_dot_f64c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_real_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(ab_real_vec), sign_flip_vec)); @@ -1437,8 +1455,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64 results[1] = _mm512_reduce_add_pd(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512d ab_real_vec = _mm512_setzero_pd(); __m512d ab_imag_vec = _mm512_setzero_pd(); __m512d a_vec; @@ -1465,7 +1483,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f6 a_vec = _mm512_maskz_loadu_pd(mask, a); b_vec = _mm512_maskz_loadu_pd(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_pd(a); b_vec = _mm512_loadu_pd(b); a += 8, b += 8, n -= 8; @@ -1473,8 +1492,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f6 ab_real_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_real_vec); b_vec = _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(b_vec), swap_adjacent_vec)); ab_imag_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_imag_vec); - if (n) - goto simsimd_vdot_f64c_skylake_cycle; + if (n) goto simsimd_vdot_f64c_skylake_cycle; // Flip the sign bit in every second scalar before accumulation: ab_imag_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(ab_imag_vec), sign_flip_vec)); @@ -1491,11 +1509,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f6 #if SIMSIMD_TARGET_GENOA #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 ab_vec = _mm512_setzero_ps(); __m512i a_i16_vec, b_i16_vec; @@ -1505,20 +1523,20 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; } ab_vec = _mm512_dpbf16_ps(ab_vec, (__m512bh)(a_i16_vec), (__m512bh)(b_i16_vec)); - if (n) - goto simsimd_dot_bf16_genoa_cycle; + if (n) goto simsimd_dot_bf16_genoa_cycle; *result = _simsimd_reduce_f32x16_skylake(ab_vec); } 
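// `_mm512_dpbf16_ps` above consumes 32 bf16 lanes per operand and folds each adjacent
// pair of products into one of 16 f32 accumulator lanes. A scalar model of a single
// lane update, ignoring the instruction's internal rounding and subnormal handling;
// the names are descriptive stand-ins, not Intel's:
#include <stdint.h>
#include <string.h>

static float bf16_lane_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static float dpbf16_lane_model(float acc, uint16_t const a_pair[2], uint16_t const b_pair[2]) {
    return acc + bf16_lane_to_f32(a_pair[0]) * bf16_lane_to_f32(b_pair[0]) +
           bf16_lane_to_f32(a_pair[1]) * bf16_lane_to_f32(b_pair[1]);
}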
-SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512 ab_real_vec = _mm512_setzero_ps(); __m512 ab_imag_vec = _mm512_setzero_ps(); __m512i a_vec; @@ -1543,7 +1561,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf1 a_vec = _mm512_maskz_loadu_epi16(mask, a); b_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_epi16(a); b_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1551,16 +1570,15 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf1 ab_real_vec = _mm512_dpbf16_ps(ab_real_vec, (__m512bh)(_mm512_xor_si512(b_vec, sign_flip_vec)), (__m512bh)(a_vec)); ab_imag_vec = _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), (__m512bh)(a_vec)); - if (n) - goto simsimd_dot_bf16c_genoa_cycle; + if (n) goto simsimd_dot_bf16c_genoa_cycle; // Reduce horizontal sums: results[0] = _simsimd_reduce_f32x16_skylake(ab_real_vec); results[1] = _simsimd_reduce_f32x16_skylake(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512 ab_real_vec = _mm512_setzero_ps(); __m512 ab_imag_vec = _mm512_setzero_ps(); __m512i a_vec; @@ -1585,7 +1603,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf a_vec = _mm512_maskz_loadu_epi16(mask, a); b_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_epi16(a); b_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1594,8 +1613,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); ab_imag_vec = _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(a_vec), (__m512bh)(b_vec)); - if (n) - goto simsimd_dot_bf16c_genoa_cycle; + if (n) goto simsimd_dot_bf16c_genoa_cycle; // Reduce horizontal sums: results[0] = _simsimd_reduce_f32x16_skylake(ab_real_vec); @@ -1609,11 +1627,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf #if SIMSIMD_TARGET_SAPPHIRE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512fp16") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512h ab_vec = _mm512_setzero_ph(); __m512i a_i16_vec, b_i16_vec; @@ -1623,20 +1641,20 @@ SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const* a, simsimd_f16 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); 
b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; } ab_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec), ab_vec); - if (n) - goto simsimd_dot_f16_sapphire_cycle; + if (n) goto simsimd_dot_f16_sapphire_cycle; *result = _mm512_reduce_add_ph(ab_vec); } -SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512h ab_real_vec = _mm512_setzero_ph(); __m512h ab_imag_vec = _mm512_setzero_ph(); __m512i a_vec; @@ -1661,7 +1679,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f1 a_vec = _mm512_maskz_loadu_epi16(mask, a); b_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_epi16(a); b_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1671,8 +1690,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f1 _mm512_castsi512_ph(a_vec), ab_real_vec); ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), _mm512_castsi512_ph(a_vec), ab_imag_vec); - if (n) - goto simsimd_dot_f16c_sapphire_cycle; + if (n) goto simsimd_dot_f16c_sapphire_cycle; // Reduce horizontal sums: // TODO: Optimize this with tree-like reductions @@ -1680,8 +1698,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f1 results[1] = _mm512_reduce_add_ph(ab_imag_vec); } -SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *results) { __m512h ab_real_vec = _mm512_setzero_ph(); __m512h ab_imag_vec = _mm512_setzero_ph(); __m512i a_vec; @@ -1706,7 +1724,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f a_vec = _mm512_maskz_loadu_epi16(mask, a); b_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_epi16(a); b_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1715,8 +1734,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_imag_vec); - if (n) - goto simsimd_dot_f16c_sapphire_cycle; + if (n) goto simsimd_dot_f16c_sapphire_cycle; // Reduce horizontal sums: results[0] = _mm512_reduce_add_ph(ab_real_vec); @@ -1730,11 +1748,11 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f #if SIMSIMD_TARGET_ICE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vnni") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i ab_i32_vec 
= _mm512_setzero_si512(); __m512i a_i16_vec, b_i16_vec; @@ -1744,9 +1762,10 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const a_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); n = 0; - } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)b)); + } + else { + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)a)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)b)); a += 32, b += 32, n -= 32; } // Unfortunately we can't use the `_mm512_dpbusd_epi32` intrinsics here either, @@ -1755,14 +1774,13 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const // So we have to use the `_mm512_dpwssd_epi32` intrinsics instead, upcasting // to 16-bit beforehand. ab_i32_vec = _mm512_dpwssd_epi32(ab_i32_vec, a_i16_vec, b_i16_vec); - if (n) - goto simsimd_dot_i8_ice_cycle; + if (n) goto simsimd_dot_i8_ice_cycle; *result = _mm512_reduce_add_epi32(ab_i32_vec); } -SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i ab_i32_low_vec = _mm512_setzero_si512(); __m512i ab_i32_high_vec = _mm512_setzero_si512(); __m512i const zeros_vec = _mm512_setzero_si512(); @@ -1775,7 +1793,8 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); n = 0; - } else { + } + else { a_u8_vec = _mm512_loadu_si512(a); b_u8_vec = _mm512_loadu_si512(b); a += 64, b += 64, n -= 64; @@ -1794,8 +1813,7 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const // to 16-bit beforehand. 
ab_i32_low_vec = _mm512_dpwssd_epi32(ab_i32_low_vec, a_i16_low_vec, b_i16_low_vec); ab_i32_high_vec = _mm512_dpwssd_epi32(ab_i32_high_vec, a_i16_high_vec, b_i16_high_vec); - if (n) - goto simsimd_dot_u8_ice_cycle; + if (n) goto simsimd_dot_u8_ice_cycle; *result = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); } @@ -1809,15 +1827,15 @@ SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const #pragma GCC target("avx2", "bmi2", "avx2vnni") #pragma clang attribute push(__attribute__((target("avx2,bmi2,avx2vnni"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_vec = _mm256_setzero_si256(); simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); ab_i32_vec = _mm256_dpbssds_epi32(ab_i32_vec, a_i8_vec, b_i8_vec); } @@ -1825,8 +1843,7 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t co int ab = _simsimd_reduce_i32x8_haswell(ab_i32_vec); // Take care of the tail: - for (; i < n; ++i) - ab += (int)(a[i]) * b[i]; + for (; i < n; ++i) ab += (int)(a[i]) * b[i]; *result = ab; } @@ -1834,7 +1851,7 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t co #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SIERRA -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 #ifdef __cplusplus } diff --git a/include/simsimd/fma.h b/include/simsimd/fma.h index 858d0496..ab99027c 100644 --- a/include/simsimd/fma.h +++ b/include/simsimd/fma.h @@ -5,8 +5,14 @@ * @date October 16, 2024 * * Contains the following element-wise operations: - * - Weighted Sum: Oq[i] = Alpha * X[i] + Beta * Z[i] - * - FMA or Fused-Multiply-Add: O[i] = Alpha * X[i] * Y[i] + Beta * Z[i] + * - WSum or Weighted-Sum: R[i] = Alpha * A[i] + Beta * B[i] + * - FMA or Fused-Multiply-Add: R[i] = Alpha * A[i] * B[i] + Beta * C[i] + * + * This tiny set of operations is enough to implement a wide range of algorithms. + * To scale a vector by a scalar, just call WSum with $Beta$ = 0. + * To sum two vectors, just call WSum with $Alpha$ = $Beta$ = 1. + * To average two vectors, just call WSum with $Alpha$ = $Beta$ = 0.5. + * To multiply vectors element-wise, just call FMA with $Beta$ = 0. * * For datatypes: * - 64-bit IEEE floating point numbers @@ -20,6 +26,11 @@ * - Arm: NEON, SVE * - x86: Haswell, Ice Lake, Skylake, Genoa, Sapphire * + * We use `f16` for `i8` and `u8` arithmetic. This is because Arm received `f16` support earlier than `bf16`. + * For example, Apple M1 has `f16` support and `bf16` was only added in M2. On the other hand, on paper, + * AMD Genoa has `bf16` support, and `f16` is only available on Intel Sapphire Rapids and newer. + * Sadly, the SIMD support for `bf16` is limited to mixed-precision dot-products, which makes it useless here.
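+ *
+ * For example, the `f32` serial backends declared below can express all four of the
+ * patterns above (an illustrative sketch, not part of this patch):
+ *
+ *     simsimd_f32_t x[4] = {1, 2, 3, 4}, y[4] = {4, 3, 2, 1}, out[4];
+ *     simsimd_wsum_f32_serial(x, y, 4, 2.0, 0.0, out);   // scale:   out[i] = 2 * x[i]
+ *     simsimd_wsum_f32_serial(x, y, 4, 1.0, 1.0, out);   // sum:     out[i] = x[i] + y[i]
+ *     simsimd_wsum_f32_serial(x, y, 4, 0.5, 0.5, out);   // average: out[i] = (x[i] + y[i]) / 2
+ *     simsimd_fma_f32_serial(x, y, y, 4, 1.0, 0.0, out); // product: out[i] = x[i] * y[i]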
+ * * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ */ @@ -32,6 +43,1100 @@ extern "C" { #endif +SIMSIMD_PUBLIC void simsimd_wsum_f64_serial( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f32_serial( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f16_serial( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_bf16_serial( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_i8_serial( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_u8_serial( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); + +SIMSIMD_PUBLIC void simsimd_fma_f64_serial( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f32_serial( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f16_serial( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_bf16_serial( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_i8_serial( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_fma_u8_serial( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); + +#define SIMSIMD_MAKE_WSUM(name, input_type, accumulator_type, load_and_convert, convert_and_store) \ + SIMSIMD_PUBLIC void simsimd_wsum_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_##input_type##_t *result) { \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + simsimd_##accumulator_type##_t ai_scaled = (simsimd_##accumulator_type##_t)(ai * alpha); \ + simsimd_##accumulator_type##_t bi_scaled = (simsimd_##accumulator_type##_t)(bi * beta); \ + simsimd_##accumulator_type##_t sum = ai_scaled + bi_scaled; \ + convert_and_store(sum, result + i); \ + } \ + } + +#define 
SIMSIMD_MAKE_FMA(name, input_type, accumulator_type, load_and_convert, convert_and_store) \ + SIMSIMD_PUBLIC void simsimd_fma_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \ + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_##input_type##_t *result) { \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + simsimd_##accumulator_type##_t ci = load_and_convert(c + i); \ + simsimd_##accumulator_type##_t abi_scaled = (simsimd_##accumulator_type##_t)(ai * bi * alpha); \ + simsimd_##accumulator_type##_t ci_scaled = (simsimd_##accumulator_type##_t)(ci * beta); \ + simsimd_##accumulator_type##_t sum = abi_scaled + ci_scaled; \ + convert_and_store(sum, result + i); \ + } \ + } + +SIMSIMD_MAKE_WSUM(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f64_serial +SIMSIMD_MAKE_WSUM(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_serial +SIMSIMD_MAKE_WSUM(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_serial +SIMSIMD_MAKE_WSUM(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_serial +SIMSIMD_MAKE_WSUM(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_wsum_i8_serial +SIMSIMD_MAKE_WSUM(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_wsum_u8_serial + +SIMSIMD_MAKE_WSUM(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_accurate +SIMSIMD_MAKE_WSUM(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_accurate +SIMSIMD_MAKE_WSUM(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_accurate +SIMSIMD_MAKE_WSUM(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_wsum_i8_accurate +SIMSIMD_MAKE_WSUM(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_wsum_u8_accurate + +SIMSIMD_MAKE_FMA(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f64_serial +SIMSIMD_MAKE_FMA(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_serial +SIMSIMD_MAKE_FMA(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_serial +SIMSIMD_MAKE_FMA(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_serial +SIMSIMD_MAKE_FMA(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_fma_i8_serial +SIMSIMD_MAKE_FMA(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_fma_u8_serial + +SIMSIMD_MAKE_FMA(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_accurate +SIMSIMD_MAKE_FMA(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_accurate +SIMSIMD_MAKE_FMA(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_accurate +SIMSIMD_MAKE_FMA(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_fma_i8_accurate +SIMSIMD_MAKE_FMA(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_fma_u8_accurate + +SIMSIMD_PUBLIC void simsimd_wsum_f64_haswell( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f32_haswell( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t 
alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f16_haswell( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_bf16_haswell( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_i8_haswell( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_u8_haswell( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f64_haswell( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f32_haswell( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f16_haswell( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_bf16_haswell( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_i8_haswell( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_fma_u8_haswell( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); + +/* SIMD-powered backends for various generations of AVX512 CPUs. + * Unlike the distance metrics, the SIMD implementation of FMA and WSum benefits from aligned stores. + * Assuming the size of the ZMM register matches the width of the cache line, we skip the unaligned head + * and tail of the output buffer, and only use aligned stores in the main loop.
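+ *
+ * A sketch of the pointer arithmetic behind that idea, assuming a 64-byte cache line
+ * (illustrative only, not part of this patch):
+ *
+ *     simsimd_size_t head_bytes = (64 - ((uintptr_t)result % 64)) % 64; // bytes until the next cache line
+ *     simsimd_size_t head_scalars = head_bytes / sizeof(*result);      // scalars to store before aligning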
+ */ +SIMSIMD_PUBLIC void simsimd_wsum_f64_skylake( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f32_skylake( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_bf16_skylake( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); + +SIMSIMD_PUBLIC void simsimd_fma_f64_skylake( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f32_skylake( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_fma_bf16_skylake( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); + +SIMSIMD_PUBLIC void simsimd_wsum_f16_sapphire( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_i8_sapphire( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_u8_sapphire( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); + +SIMSIMD_PUBLIC void simsimd_fma_f16_sapphire( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); +SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); + +#if _SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("avx2", "f16c", "fma") +#pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_f32_haswell( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_loadu_ps(a + i); + __m256 b_vec = _mm256_loadu_ps(b + i); + __m256 a_scaled = _mm256_mul_ps(a_vec, alpha_vec); + __m256 b_scaled = _mm256_mul_ps(b_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(a_scaled, b_scaled); + _mm256_storeu_ps(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha_f32 
* a[i] + beta_f32 * b[i]; +} + +SIMSIMD_PUBLIC void simsimd_wsum_f64_haswell( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result) { + __m256d alpha_vec = _mm256_set1_pd(alpha); + __m256d beta_vec = _mm256_set1_pd(beta); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256d a_vec = _mm256_loadu_pd(a + i); + __m256d b_vec = _mm256_loadu_pd(b + i); + __m256d a_scaled = _mm256_mul_pd(a_vec, alpha_vec); + __m256d b_scaled = _mm256_mul_pd(b_vec, beta_vec); + __m256d sum_vec = _mm256_add_pd(a_scaled, b_scaled); + _mm256_storeu_pd(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha * a[i] + beta * b[i]; +} + +SIMSIMD_PUBLIC void simsimd_wsum_f16_haswell( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m128i a_f16 = _mm_loadu_si128((__m128i const *)(a + i)); + __m128i b_f16 = _mm_loadu_si128((__m128i const *)(b + i)); + __m256 a_vec = _mm256_cvtph_ps(a_f16); + __m256 b_vec = _mm256_cvtph_ps(b_f16); + __m256 a_scaled = _mm256_mul_ps(a_vec, alpha_vec); + __m256 b_scaled = _mm256_mul_ps(b_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(a_scaled, b_scaled); + __m128i sum_f16 = _mm256_cvtps_ph(sum_vec, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + _mm_storeu_si128((__m128i *)(result + i), sum_f16); + } + + // The tail: + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_F16_TO_F32(a + i); + simsimd_f32_t bi = SIMSIMD_F16_TO_F32(b + i); + simsimd_f32_t sum = alpha_f32 * ai + beta_f32 * bi; + SIMSIMD_F32_TO_F16(sum, result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_wsum_bf16_haswell( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m128i a_bf16 = _mm_loadu_si128((__m128i const *)(a + i)); + __m128i b_bf16 = _mm_loadu_si128((__m128i const *)(b + i)); + __m256 a_vec = _simsimd_bf16x8_to_f32x8_haswell(a_bf16); + __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(b_bf16); + __m256 a_scaled = _mm256_mul_ps(a_vec, alpha_vec); + __m256 b_scaled = _mm256_mul_ps(b_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(a_scaled, b_scaled); + __m128i sum_bf16 = _simsimd_f32x8_to_bf16x8_haswell(sum_vec); + _mm_storeu_si128((__m128i *)(result + i), sum_bf16); + } + + // The tail: + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_BF16_TO_F32(a + i); + simsimd_f32_t bi = SIMSIMD_BF16_TO_F32(b + i); + simsimd_f32_t sum = alpha_f32 * ai + beta_f32 * bi; + SIMSIMD_F32_TO_BF16(sum, result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_fma_f32_haswell( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 
alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_loadu_ps(a + i); + __m256 b_vec = _mm256_loadu_ps(b + i); + __m256 c_vec = _mm256_loadu_ps(c + i); + __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec); + __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec); + __m256 c_scaled_vec = _mm256_mul_ps(c_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(ab_scaled_vec, c_scaled_vec); + _mm256_storeu_ps(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha_f32 * a[i] * b[i] + beta_f32 * c[i]; +} + +SIMSIMD_PUBLIC void simsimd_fma_f64_haswell( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result) { + __m256d alpha_vec = _mm256_set1_pd(alpha); + __m256d beta_vec = _mm256_set1_pd(beta); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256d a_vec = _mm256_loadu_pd(a + i); + __m256d b_vec = _mm256_loadu_pd(b + i); + __m256d c_vec = _mm256_loadu_pd(c + i); + __m256d ab_vec = _mm256_mul_pd(a_vec, b_vec); + __m256d ab_scaled_vec = _mm256_mul_pd(ab_vec, alpha_vec); + __m256d c_scaled_vec = _mm256_mul_pd(c_vec, beta_vec); + __m256d sum_vec = _mm256_add_pd(ab_scaled_vec, c_scaled_vec); + _mm256_storeu_pd(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha * a[i] * b[i] + beta * c[i]; +} + +SIMSIMD_PUBLIC void simsimd_fma_f16_haswell( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m128i a_f16 = _mm_loadu_si128((__m128i const *)(a + i)); + __m128i b_f16 = _mm_loadu_si128((__m128i const *)(b + i)); + __m128i c_f16 = _mm_loadu_si128((__m128i const *)(c + i)); + __m256 a_vec = _mm256_cvtph_ps(a_f16); + __m256 b_vec = _mm256_cvtph_ps(b_f16); + __m256 c_vec = _mm256_cvtph_ps(c_f16); + __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec); + __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec); + __m256 c_scaled_vec = _mm256_mul_ps(c_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(ab_scaled_vec, c_scaled_vec); + __m128i sum_f16 = _mm256_cvtps_ph(sum_vec, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + _mm_storeu_si128((__m128i *)(result + i), sum_f16); + } + + // The tail: + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_F16_TO_F32(a + i); + simsimd_f32_t bi = SIMSIMD_F16_TO_F32(b + i); + simsimd_f32_t ci = SIMSIMD_F16_TO_F32(c + i); + simsimd_f32_t sum = alpha * ai * bi + beta * ci; + SIMSIMD_F32_TO_F16(sum, result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_fma_bf16_haswell( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + __m256 alpha_vec = _mm256_set1_ps(alpha_f32); + __m256 beta_vec = _mm256_set1_ps(beta_f32); + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m128i a_bf16 = _mm_loadu_si128((__m128i const *)(a + i)); + __m128i 
b_bf16 = _mm_loadu_si128((__m128i const *)(b + i)); + __m128i c_bf16 = _mm_loadu_si128((__m128i const *)(c + i)); + __m256 a_vec = _simsimd_bf16x8_to_f32x8_haswell(a_bf16); + __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(b_bf16); + __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell(c_bf16); + __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec); + __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec); + __m256 c_scaled_vec = _mm256_mul_ps(c_vec, beta_vec); + __m256 sum_vec = _mm256_add_ps(ab_scaled_vec, c_scaled_vec); + __m128i sum_bf16 = _simsimd_f32x8_to_bf16x8_haswell(sum_vec); + _mm_storeu_si128((__m128i *)(result + i), sum_bf16); + } + + // The tail: + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_BF16_TO_F32(a + i); + simsimd_f32_t bi = SIMSIMD_BF16_TO_F32(b + i); + simsimd_f32_t ci = SIMSIMD_BF16_TO_F32(c + i); + simsimd_f32_t sum = alpha * ai * bi + beta * ci; + SIMSIMD_F32_TO_BF16(sum, result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_wsum_i8_haswell( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + simsimd_wsum_i8_serial(a, b, n, alpha, beta, result); // TODO +} +SIMSIMD_PUBLIC void simsimd_wsum_u8_haswell( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + simsimd_wsum_u8_serial(a, b, n, alpha, beta, result); // TODO +} +SIMSIMD_PUBLIC void simsimd_fma_i8_haswell( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + simsimd_fma_i8_serial(a, b, c, n, alpha, beta, result); // TODO +} +SIMSIMD_PUBLIC void simsimd_fma_u8_haswell( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + simsimd_fma_u8_serial(a, b, c, n, alpha, beta, result); // TODO +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "bmi2") +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,bmi2"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_f64_skylake( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result) { + __m512d alpha_vec = _mm512_set1_pd(alpha); + __m512d beta_vec = _mm512_set1_pd(beta); + __m512d a_vec, b_vec, a_scaled_vec, sum_vec; + __mmask8 mask = 0xFF; + +simsimd_wsum_f64_skylake_cycle: + if (n < 8) { + mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + n = 0; + } + else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + a += 8, b += 8, n -= 8; + } + a_scaled_vec = _mm512_mul_pd(a_vec, alpha_vec); + sum_vec = _mm512_fmadd_pd(b_vec, beta_vec, a_scaled_vec); + _mm512_mask_storeu_pd(result, mask, sum_vec); + result += 8; + if (n) goto simsimd_wsum_f64_skylake_cycle; +} + +SIMSIMD_PUBLIC void simsimd_wsum_f32_skylake( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + __m512 alpha_vec = _mm512_set1_ps(alpha); + __m512 beta_vec = _mm512_set1_ps(beta); + __m512 a_vec, b_vec, a_scaled_vec, 
sum_vec; + __mmask16 mask = 0xFFFF; + +simsimd_wsum_f32_skylake_cycle: + if (n < 16) { + mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } + else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + a_scaled_vec = _mm512_mul_ps(a_vec, alpha_vec); + sum_vec = _mm512_fmadd_ps(b_vec, beta_vec, a_scaled_vec); + _mm512_mask_storeu_ps(result, mask, sum_vec); + result += 16; + if (n) goto simsimd_wsum_f32_skylake_cycle; +} + +SIMSIMD_PUBLIC void simsimd_wsum_bf16_skylake( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + __m512 alpha_vec = _mm512_set1_ps(alpha); + __m512 beta_vec = _mm512_set1_ps(beta); + __m256i a_bf16_vec, b_bf16_vec, sum_bf16_vec; + __m512 a_vec, b_vec, a_scaled_vec, sum_vec; + __mmask16 mask = 0xFFFF; + +simsimd_wsum_bf16_skylake_cycle: + if (n < 16) { + mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_bf16_vec = _mm256_maskz_loadu_epi16(mask, a); + b_bf16_vec = _mm256_maskz_loadu_epi16(mask, b); + n = 0; + } + else { + a_bf16_vec = _mm256_loadu_epi16(a); + b_bf16_vec = _mm256_loadu_epi16(b); + a += 16, b += 16, n -= 16; + } + a_vec = _simsimd_bf16x16_to_f32x16_skylake(a_bf16_vec); + b_vec = _simsimd_bf16x16_to_f32x16_skylake(b_bf16_vec); + a_scaled_vec = _mm512_mul_ps(a_vec, alpha_vec); + sum_vec = _mm512_fmadd_ps(b_vec, beta_vec, a_scaled_vec); + sum_bf16_vec = _simsimd_f32x16_to_bf16x16_skylake(sum_vec); + _mm256_mask_storeu_epi16(result, mask, sum_bf16_vec); + result += 16; + if (n) goto simsimd_wsum_bf16_skylake_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_f64_skylake( // + simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result) { + __m512d alpha_vec = _mm512_set1_pd(alpha); + __m512d beta_vec = _mm512_set1_pd(beta); + __m512d a_vec, b_vec, c_vec, ab_vec, ab_scaled_vec, sum_vec; + __mmask8 mask = 0xFF; + +simsimd_fma_f64_skylake_cycle: + if (n < 8) { + mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + c_vec = _mm512_maskz_loadu_pd(mask, c); + n = 0; + } + else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + c_vec = _mm512_loadu_pd(c); + a += 8, b += 8, c += 8, n -= 8; + } + ab_vec = _mm512_mul_pd(a_vec, b_vec); + ab_scaled_vec = _mm512_mul_pd(ab_vec, alpha_vec); + sum_vec = _mm512_fmadd_pd(c_vec, beta_vec, ab_scaled_vec); + _mm512_mask_storeu_pd(result, mask, sum_vec); + result += 8; + if (n) goto simsimd_fma_f64_skylake_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_f32_skylake( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + __m512 alpha_vec = _mm512_set1_ps(alpha); + __m512 beta_vec = _mm512_set1_ps(beta); + __m512 a_vec, b_vec, c_vec, ab_vec, ab_scaled_vec, sum_vec; + __mmask16 mask = 0xFFFF; + +simsimd_fma_f32_skylake_cycle: + if (n < 16) { + mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + c_vec = _mm512_maskz_loadu_ps(mask, c); + n = 0; + } + else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + c_vec = _mm512_loadu_ps(c); + a += 16, b += 16, c += 16, n -= 16; + } + ab_vec = _mm512_mul_ps(a_vec, b_vec); 
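+ // Scale the product by alpha, then fold in beta * c with a single fused multiply-add: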
+ ab_scaled_vec = _mm512_mul_ps(ab_vec, alpha_vec); + sum_vec = _mm512_fmadd_ps(c_vec, beta_vec, ab_scaled_vec); + _mm512_mask_storeu_ps(result, mask, sum_vec); + result += 16; + if (n) goto simsimd_fma_f32_skylake_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_bf16_skylake( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + __m512 alpha_vec = _mm512_set1_ps(alpha); + __m512 beta_vec = _mm512_set1_ps(beta); + __m256i a_bf16_vec, b_bf16_vec, c_bf16_vec, sum_bf16_vec; + __m512 a_vec, b_vec, c_vec, ab_vec, ab_scaled_vec, sum_vec; + __mmask16 mask = 0xFFFF; + +simsimd_fma_bf16_skylake_cycle: + if (n < 16) { + mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_bf16_vec = _mm256_maskz_loadu_epi16(mask, a); + b_bf16_vec = _mm256_maskz_loadu_epi16(mask, b); + c_bf16_vec = _mm256_maskz_loadu_epi16(mask, c); + n = 0; + } + else { + a_bf16_vec = _mm256_loadu_epi16(a); + b_bf16_vec = _mm256_loadu_epi16(b); + c_bf16_vec = _mm256_loadu_epi16(c); + a += 16, b += 16, c += 16, n -= 16; + } + a_vec = _simsimd_bf16x16_to_f32x16_skylake(a_bf16_vec); + b_vec = _simsimd_bf16x16_to_f32x16_skylake(b_bf16_vec); + c_vec = _simsimd_bf16x16_to_f32x16_skylake(c_bf16_vec); + ab_vec = _mm512_mul_ps(a_vec, b_vec); + ab_scaled_vec = _mm512_mul_ps(ab_vec, alpha_vec); + sum_vec = _mm512_fmadd_ps(c_vec, beta_vec, ab_scaled_vec); + sum_bf16_vec = _simsimd_f32x16_to_bf16x16_skylake(sum_vec); + _mm256_mask_storeu_epi16(result, mask, sum_bf16_vec); + result += 16; + if (n) goto simsimd_fma_bf16_skylake_cycle; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SKYLAKE + +#if SIMSIMD_TARGET_SAPPHIRE +#pragma GCC push_options +#pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512fp16") +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), \ + apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_f16_sapphire( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + + __mmask32 mask = 0xFFFFFFFF; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512h a_f16_vec, b_f16_vec, c_f16_vec; + __m512h a_scaled_f16_vec, sum_f16_vec; + +simsimd_wsum_f16_sapphire_cycle: + if (n < 32) { + mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + b_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + n = 0; + } + else { + a_f16_vec = _mm512_loadu_ph(a); + b_f16_vec = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + } + a_scaled_f16_vec = _mm512_mul_ph(a_f16_vec, alpha_vec); + sum_f16_vec = _mm512_fmadd_ph(b_f16_vec, beta_vec, a_scaled_f16_vec); + _mm512_mask_storeu_epi16(result, mask, _mm512_castph_si512(sum_f16_vec)); + result += 32; + if (n) goto simsimd_wsum_f16_sapphire_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_f16_sapphire( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + + __mmask32 mask = 0xFFFFFFFF; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512h a_f16_vec, b_f16_vec, c_f16_vec; + __m512h ab_f16_vec, ab_scaled_f16_vec, sum_f16_vec; + +simsimd_fma_f16_sapphire_cycle: + if (n < 
32) { + mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + b_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + c_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, c)); + n = 0; + } + else { + a_f16_vec = _mm512_loadu_ph(a); + b_f16_vec = _mm512_loadu_ph(b); + c_f16_vec = _mm512_loadu_ph(c); + a += 32, b += 32, c += 32, n -= 32; + } + ab_f16_vec = _mm512_mul_ph(a_f16_vec, b_f16_vec); + ab_scaled_f16_vec = _mm512_mul_ph(ab_f16_vec, alpha_vec); + sum_f16_vec = _mm512_fmadd_ph(c_f16_vec, beta_vec, ab_scaled_f16_vec); + _mm512_mask_storeu_epi16(result, mask, _mm512_castph_si512(sum_f16_vec)); + result += 32; + if (n) goto simsimd_fma_f16_sapphire_cycle; +} + +SIMSIMD_PUBLIC void simsimd_wsum_u8_sapphire( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + + __mmask64 mask = 0xFFFFFFFFFFFFFFFFull; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512i a_u8_vec, b_u8_vec, c_u8_vec, sum_u8_vec; + __m512h a_f16_low_vec, a_f16_high_vec, b_f16_low_vec, b_f16_high_vec; + __m512h a_scaled_f16_low_vec, a_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec; + __m512i sum_i16_low_vec, sum_i16_high_vec; + +simsimd_wsum_u8_sapphire_cycle: + if (n < 64) { + mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } + else { + a_u8_vec = _mm512_loadu_epi8(a); + b_u8_vec = _mm512_loadu_epi8(b); + a += 64, b += 64, n -= 64; + } + // Upcast: + a_f16_low_vec = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8_vec, _mm512_setzero_si512())); + a_f16_high_vec = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8_vec, _mm512_setzero_si512())); + b_f16_low_vec = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8_vec, _mm512_setzero_si512())); + b_f16_high_vec = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8_vec, _mm512_setzero_si512())); + // Scale: + a_scaled_f16_low_vec = _mm512_mul_ph(a_f16_low_vec, alpha_vec); + a_scaled_f16_high_vec = _mm512_mul_ph(a_f16_high_vec, alpha_vec); + // Add: + sum_f16_low_vec = _mm512_fmadd_ph(b_f16_low_vec, beta_vec, a_scaled_f16_low_vec); + sum_f16_high_vec = _mm512_fmadd_ph(b_f16_high_vec, beta_vec, a_scaled_f16_high_vec); + // Downcast: + sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec); + sum_i16_high_vec = _mm512_cvtph_epi16(sum_f16_high_vec); + sum_u8_vec = _mm512_packus_epi16(sum_i16_low_vec, sum_i16_high_vec); + _mm512_mask_storeu_epi8(result, mask, sum_u8_vec); + result += 64; + if (n) goto simsimd_wsum_u8_sapphire_cycle; +} + +SIMSIMD_PUBLIC void simsimd_wsum_i8_sapphire( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + + __mmask64 mask = 0xFFFFFFFFFFFFFFFFull; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512i a_i8_vec, b_i8_vec, c_i8_vec, sum_i8_vec; + __m512h a_f16_low_vec, a_f16_high_vec, b_f16_low_vec, b_f16_high_vec; + __m512h a_scaled_f16_low_vec, a_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec; + __m512i sum_i16_low_vec, sum_i16_high_vec; + +simsimd_wsum_i8_sapphire_cycle: + if (n < 64) { + mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n); + a_i8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_i8_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } + 
else { + a_i8_vec = _mm512_loadu_epi8(a); + b_i8_vec = _mm512_loadu_epi8(b); + a += 64, b += 64, n -= 64; + } + // Upcast: + a_f16_low_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_i8_vec))); + a_f16_high_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_i8_vec, 1))); + b_f16_low_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_i8_vec))); + b_f16_high_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_i8_vec, 1))); + // Scale: + a_scaled_f16_low_vec = _mm512_mul_ph(a_f16_low_vec, alpha_vec); + a_scaled_f16_high_vec = _mm512_mul_ph(a_f16_high_vec, alpha_vec); + // Add: + sum_f16_low_vec = _mm512_fmadd_ph(b_f16_low_vec, beta_vec, a_scaled_f16_low_vec); + sum_f16_high_vec = _mm512_fmadd_ph(b_f16_high_vec, beta_vec, a_scaled_f16_high_vec); + // Downcast: + sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec); + sum_i16_high_vec = _mm512_cvtph_epi16(sum_f16_high_vec); + sum_i8_vec = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(sum_i16_low_vec)), + _mm512_cvtsepi16_epi8(sum_i16_high_vec), 1); + _mm512_mask_storeu_epi8(result, mask, sum_i8_vec); + result += 64; + if (n) goto simsimd_wsum_i8_sapphire_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + + __mmask64 mask = 0xFFFFFFFFFFFFFFFF; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512i a_i8_vec, b_i8_vec, c_i8_vec, sum_i8_vec; + __m512h a_f16_low_vec, a_f16_high_vec, b_f16_low_vec, b_f16_high_vec; + __m512h c_f16_low_vec, c_f16_high_vec, ab_f16_low_vec, ab_f16_high_vec; + __m512h ab_scaled_f16_low_vec, ab_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec; + __m512i sum_i16_low_vec, sum_i16_high_vec; + +simsimd_fma_i8_sapphire_cycle: + if (n < 64) { + mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_i8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_i8_vec = _mm512_maskz_loadu_epi8(mask, b); + c_i8_vec = _mm512_maskz_loadu_epi8(mask, c); + n = 0; + } + else { + a_i8_vec = _mm512_loadu_epi8(a); + b_i8_vec = _mm512_loadu_epi8(b); + c_i8_vec = _mm512_loadu_epi8(c); + a += 64, b += 64, c += 64, n -= 64; + } + // Upcast: + a_f16_low_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_i8_vec))); + a_f16_high_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_i8_vec, 1))); + b_f16_low_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_i8_vec))); + b_f16_high_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_i8_vec, 1))); + c_f16_low_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(c_i8_vec))); + c_f16_high_vec = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(c_i8_vec, 1))); + // Multiply: + ab_f16_low_vec = _mm512_mul_ph(a_f16_low_vec, b_f16_low_vec); + ab_f16_high_vec = _mm512_mul_ph(a_f16_high_vec, b_f16_high_vec); + // Scale: + ab_scaled_f16_low_vec = _mm512_mul_ph(ab_f16_low_vec, alpha_vec); + ab_scaled_f16_high_vec = _mm512_mul_ph(ab_f16_high_vec, alpha_vec); + // Add: + sum_f16_low_vec = _mm512_fmadd_ph(c_f16_low_vec, beta_vec, ab_scaled_f16_low_vec); + sum_f16_high_vec = _mm512_fmadd_ph(c_f16_high_vec, beta_vec, ab_scaled_f16_high_vec); + // Downcast: + sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec); + sum_i16_high_vec = 
_mm512_cvtph_epi16(sum_f16_high_vec); + sum_i8_vec = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(sum_i16_low_vec)), + _mm512_cvtsepi16_epi8(sum_i16_high_vec), 1); + _mm512_mask_storeu_epi8(result, mask, sum_i8_vec); + result += 64; + if (n) goto simsimd_fma_i8_sapphire_cycle; +} + +SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + + __mmask64 mask = 0xFFFFFFFFFFFFFFFF; + __m512h alpha_vec = _mm512_set1_ph((_Float16)alpha); + __m512h beta_vec = _mm512_set1_ph((_Float16)beta); + __m512i a_u8_vec, b_u8_vec, c_u8_vec, sum_u8_vec; + __m512h a_f16_low_vec, a_f16_high_vec, b_f16_low_vec, b_f16_high_vec; + __m512h c_f16_low_vec, c_f16_high_vec, ab_f16_low_vec, ab_f16_high_vec; + __m512h ab_scaled_f16_low_vec, ab_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec; + __m512i sum_i16_low_vec, sum_i16_high_vec; + +simsimd_fma_u8_sapphire_cycle: + if (n < 64) { + mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); + c_u8_vec = _mm512_maskz_loadu_epi8(mask, c); + n = 0; + } + else { + a_u8_vec = _mm512_loadu_epi8(a); + b_u8_vec = _mm512_loadu_epi8(b); + c_u8_vec = _mm512_loadu_epi8(c); + a += 64, b += 64, c += 64, n -= 64; + } + // Upcast: + a_f16_low_vec = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8_vec, _mm512_setzero_si512())); + a_f16_high_vec = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8_vec, _mm512_setzero_si512())); + b_f16_low_vec = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8_vec, _mm512_setzero_si512())); + b_f16_high_vec = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8_vec, _mm512_setzero_si512())); + c_f16_low_vec = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(c_u8_vec, _mm512_setzero_si512())); + c_f16_high_vec = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(c_u8_vec, _mm512_setzero_si512())); + // Multiply: + ab_f16_low_vec = _mm512_mul_ph(a_f16_low_vec, b_f16_low_vec); + ab_f16_high_vec = _mm512_mul_ph(a_f16_high_vec, b_f16_high_vec); + // Scale: + ab_scaled_f16_low_vec = _mm512_mul_ph(ab_f16_low_vec, alpha_vec); + ab_scaled_f16_high_vec = _mm512_mul_ph(ab_f16_high_vec, alpha_vec); + // Add: + sum_f16_low_vec = _mm512_fmadd_ph(c_f16_low_vec, beta_vec, ab_scaled_f16_low_vec); + sum_f16_high_vec = _mm512_fmadd_ph(c_f16_high_vec, beta_vec, ab_scaled_f16_high_vec); + // Downcast: + sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec); + sum_i16_high_vec = _mm512_cvtph_epi16(sum_f16_high_vec); + sum_u8_vec = _mm512_packus_epi16(sum_i16_low_vec, sum_i16_high_vec); + _mm512_mask_storeu_epi8(result, mask, sum_u8_vec); + result += 64; + if (n) goto simsimd_fma_u8_sapphire_cycle; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SAPPHIRE +#endif // _SIMSIMD_TARGET_X86 + +#if _SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+simd") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_f32_neon( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + 
float32x4_t b_vec = vld1q_f32(b + i); + float32x4_t a_scaled_vec = vmulq_n_f32(a_vec, alpha_f32); + float32x4_t b_scaled_vec = vmulq_n_f32(b_vec, beta_f32); + float32x4_t sum_vec = vaddq_f32(a_scaled_vec, b_scaled_vec); + vst1q_f32(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha_f32 * a[i] + beta_f32 * b[i]; +} + +SIMSIMD_PUBLIC void simsimd_fma_f32_neon( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + float32x4_t c_vec = vld1q_f32(c + i); + float32x4_t ab_vec = vmulq_f32(a_vec, b_vec); + float32x4_t ab_scaled_vec = vmulq_n_f32(ab_vec, alpha_f32); + float32x4_t sum_vec = vfmaq_n_f32(ab_scaled_vec, c_vec, beta_f32); + vst1q_f32(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) result[i] = alpha_f32 * a[i] * b[i] + beta_f32 * c[i]; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON + +#if SIMSIMD_TARGET_NEON_F16 +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+simd+fp16") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_f16_neon( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + float16x8_t a_vec = vld1q_f16(a + i); + float16x8_t b_vec = vld1q_f16(b + i); + float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); + float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); + float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); + vst1q_f16(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) + ((float16_t *)result)[i] = alpha_f16 * ((float16_t const *)a)[i] + beta_f16 * ((float16_t const *)b)[i]; +} + +SIMSIMD_PUBLIC void simsimd_fma_f16_neon( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + float16x8_t a_vec = vld1q_f16(a + i); + float16x8_t b_vec = vld1q_f16(b + i); + float16x8_t c_vec = vld1q_f16(c + i); + float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); + float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); + float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); + vst1q_f16(result + i, sum_vec); + } + + // The tail: + for (; i < n; ++i) + ((float16_t *)result)[i] = + alpha_f16 * ((float16_t const *)a)[i] * ((float16_t const *)b)[i] + beta_f16 * ((float16_t const *)c)[i]; +} + +SIMSIMD_PUBLIC void simsimd_wsum_u8_neon( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + uint8x8_t a_u8_vec = 
vld1_u8(a + i); + uint8x8_t b_u8_vec = vld1_u8(b + i); + float16x8_t a_vec = vcvtq_f16_u16(vmovl_u8(a_u8_vec)); + float16x8_t b_vec = vcvtq_f16_u16(vmovl_u8(b_u8_vec)); + float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); + float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); + float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); + uint8x8_t sum_u8_vec = vqmovn_u16(vcvtaq_u16_f16(sum_vec)); + vst1_u8(result + i, sum_u8_vec); + } + + // The tail: + for (; i < n; ++i) { SIMSIMD_F32_TO_U8(alpha_f16 * a[i] + beta_f16 * b[i], result + i); } +} + +SIMSIMD_PUBLIC void simsimd_fma_u8_neon( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + uint8x8_t a_u8_vec = vld1_u8(a + i); + uint8x8_t b_u8_vec = vld1_u8(b + i); + uint8x8_t c_u8_vec = vld1_u8(c + i); + float16x8_t a_vec = vcvtq_f16_u16(vmovl_u8(a_u8_vec)); + float16x8_t b_vec = vcvtq_f16_u16(vmovl_u8(b_u8_vec)); + float16x8_t c_vec = vcvtq_f16_u16(vmovl_u8(c_u8_vec)); + float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); + float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); + float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); + uint8x8_t sum_u8_vec = vqmovn_u16(vcvtaq_u16_f16(sum_vec)); + vst1_u8(result + i, sum_u8_vec); + } + + // The tail: + for (; i < n; ++i) { SIMSIMD_F32_TO_U8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i); } +} + +SIMSIMD_PUBLIC void simsimd_wsum_i8_neon( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + int8x8_t a_i8_vec = vld1_s8(a + i); + int8x8_t b_i8_vec = vld1_s8(b + i); + float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec)); + float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec)); + float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); + float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); + float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); + int8x8_t sum_i8_vec = vqmovn_s16(vcvtaq_s16_f16(sum_vec)); + vst1_s8(result + i, sum_i8_vec); + } + + // The tail: + for (; i < n; ++i) { SIMSIMD_F32_TO_I8(alpha_f16 * a[i] + beta_f16 * b[i], result + i); } +} + +SIMSIMD_PUBLIC void simsimd_fma_i8_neon( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + int8x8_t a_i8_vec = vld1_s8(a + i); + int8x8_t b_i8_vec = vld1_s8(b + i); + int8x8_t c_i8_vec = vld1_s8(c + i); + float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec)); + float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec)); + float16x8_t c_vec = vcvtq_f16_s16(vmovl_s8(c_i8_vec)); + float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); + float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); + float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); + int8x8_t sum_i8_vec = vqmovn_s16(vcvtaq_s16_f16(sum_vec)); + vst1_s8(result + i, sum_i8_vec); + } + + // The tail: + for (; i < n; ++i) {
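// `vcvtaq_s16_f16` rounds to nearest with ties away from zero, and `vqmovn_s16` saturates to
// the `int8` range; `SIMSIMD_F32_TO_I8` is expected to mirror both behaviors in this scalar
// tail, so the vector and scalar paths produce matching results.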
SIMSIMD_F32_TO_I8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i); } +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON_F16 +#endif // _SIMSIMD_TARGET_ARM + #ifdef __cplusplus } #endif diff --git a/include/simsimd/probability.h b/include/simsimd/probability.h index 12f6407b..2865aa32 100644 --- a/include/simsimd/probability.h +++ b/include/simsimd/probability.h @@ -83,32 +83,32 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); // clang-format on -#define SIMSIMD_MAKE_KL(name, input_type, accumulator_type, load_and_convert, epsilon) \ - SIMSIMD_PUBLIC void simsimd_kl_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t d = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ - d += ai * SIMSIMD_LOG((ai + epsilon) / (bi + epsilon)); \ - } \ - *result = (simsimd_distance_t)d; \ +#define SIMSIMD_MAKE_KL(name, input_type, accumulator_type, load_and_convert, epsilon) \ + SIMSIMD_PUBLIC void simsimd_kl_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t d = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + d += ai * SIMSIMD_LOG((ai + epsilon) / (bi + epsilon)); \ + } \ + *result = (simsimd_distance_t)d; \ } -#define SIMSIMD_MAKE_JS(name, input_type, accumulator_type, load_and_convert, epsilon) \ - SIMSIMD_PUBLIC void simsimd_js_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t d = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ - simsimd_##accumulator_type##_t mi = (ai + bi) / 2; \ - d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \ - d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \ - } \ - *result = (simsimd_distance_t)d / 2; \ +#define SIMSIMD_MAKE_JS(name, input_type, accumulator_type, load_and_convert, epsilon) \ + SIMSIMD_PUBLIC void simsimd_js_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t d = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + simsimd_##accumulator_type##_t mi = (ai + bi) / 2; \ + d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \ + d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \ + } \ + *result = (simsimd_distance_t)d / 2; \ } SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial @@ -132,7 +132,7 @@ SIMSIMD_MAKE_JS(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_DIVISION_EPS SIMSIMD_MAKE_KL(accurate, bf16, f64, 
SIMSIMD_BF16_TO_F32, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_bf16_accurate SIMSIMD_MAKE_JS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_bf16_accurate -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") @@ -163,8 +163,8 @@ SIMSIMD_PUBLIC float32x4_t _simsimd_log2_f32_neon(float32x4_t x) { return result; } -SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; float32x4_t epsilon_vec = vdupq_n_f32(epsilon); float32x4_t sum_vec = vdupq_n_f32(0); @@ -175,7 +175,8 @@ SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co a_vec = _simsimd_partial_load_f32x4_neon(a, n); b_vec = _simsimd_partial_load_f32x4_neon(b, n); n = 0; - } else { + } + else { a_vec = vld1q_f32(a); b_vec = vld1q_f32(b); n -= 4, a += 4, b += 4; @@ -185,16 +186,15 @@ SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co float32x4_t log_ratio_vec = _simsimd_log2_f32_neon(ratio_vec); float32x4_t prod_vec = vmulq_f32(a_vec, log_ratio_vec); sum_vec = vaddq_f32(sum_vec, prod_vec); - if (n != 0) - goto simsimd_kl_f32_neon_cycle; + if (n != 0) goto simsimd_kl_f32_neon_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; *result = sum; } -SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; float32x4_t epsilon_vec = vdupq_n_f32(epsilon); float32x4_t sum_vec = vdupq_n_f32(0); @@ -205,7 +205,8 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co a_vec = _simsimd_partial_load_f32x4_neon(a, n); b_vec = _simsimd_partial_load_f32x4_neon(b, n); n = 0; - } else { + } + else { a_vec = vld1q_f32(a); b_vec = vld1q_f32(b); n -= 4, a += 4, b += 4; @@ -219,8 +220,7 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec); float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec); sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec)); - if (n != 0) - goto simsimd_js_f32_neon_cycle; + if (n != 0) goto simsimd_js_f32_neon_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; @@ -236,8 +236,8 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co #pragma GCC target("arch=armv8.2-a+simd+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; float32x4_t epsilon_vec = vdupq_n_f32(epsilon); @@ -248,9 +248,10 @@ SIMSIMD_PUBLIC void 
simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); n = 0; - } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b)); + } + else { + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); n -= 4, a += 4, b += 4; } @@ -258,16 +259,15 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co float32x4_t log_ratio_vec = _simsimd_log2_f32_neon(ratio_vec); float32x4_t prod_vec = vmulq_f32(a_vec, log_ratio_vec); sum_vec = vaddq_f32(sum_vec, prod_vec); - if (n) - goto simsimd_kl_f16_neon_cycle; + if (n) goto simsimd_kl_f16_neon_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; *result = sum; } -SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; float32x4_t epsilon_vec = vdupq_n_f32(epsilon); @@ -278,9 +278,10 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); n = 0; - } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b)); + } + else { + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); n -= 4, a += 4, b += 4; } @@ -292,8 +293,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec); float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec); sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec)); - if (n) - goto simsimd_js_f16_neon_cycle; + if (n) goto simsimd_js_f16_neon_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; @@ -303,9 +303,9 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_NEON_F16 -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_HASWELL #pragma GCC push_options #pragma GCC target("avx2", "f16c", "fma") @@ -338,8 +338,8 @@ SIMSIMD_INTERNAL __m256 _simsimd_log2_f32_haswell(__m256 x) { return result; } -SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 sum_vec = _mm256_setzero_ps(); simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; __m256 epsilon_vec = _mm256_set1_ps(epsilon); @@ -350,9 +350,10 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t a_vec = _simsimd_partial_load_f16x8_haswell(a, n); b_vec = 
_simsimd_partial_load_f16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } a_vec = _mm256_add_ps(a_vec, epsilon_vec); @@ -361,8 +362,7 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t __m256 log_ratio_vec = _simsimd_log2_f32_haswell(ratio_vec); __m256 prod_vec = _mm256_mul_ps(a_vec, log_ratio_vec); sum_vec = _mm256_add_ps(sum_vec, prod_vec); - if (n) - goto simsimd_kl_f16_haswell_cycle; + if (n) goto simsimd_kl_f16_haswell_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_vec); @@ -370,8 +370,8 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t *result = sum; } -SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; __m256 epsilon_vec = _mm256_set1_ps(epsilon); __m256 sum_vec = _mm256_setzero_ps(); @@ -382,9 +382,10 @@ SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t a_vec = _simsimd_partial_load_f16x8_haswell(a, n); b_vec = _simsimd_partial_load_f16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } a_vec = _mm256_add_ps(a_vec, epsilon_vec); @@ -398,8 +399,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t __m256 prod_b_vec = _mm256_mul_ps(b_vec, log_ratio_b_vec); sum_vec = _mm256_add_ps(sum_vec, prod_a_vec); sum_vec = _mm256_add_ps(sum_vec, prod_b_vec); - if (n) - goto simsimd_js_f16_haswell_cycle; + if (n) goto simsimd_js_f16_haswell_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_vec); @@ -433,8 +433,8 @@ SIMSIMD_INTERNAL __m512 _simsimd_log2_f32_skylake(__m512 x) { return _mm512_add_ps(_mm512_mul_ps(p, _mm512_sub_ps(m, one)), e); } -SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 sum_vec = _mm512_setzero(); simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; __m512 epsilon_vec = _mm512_set1_ps(epsilon); @@ -446,7 +446,8 @@ SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t a_vec = _mm512_add_ps(_mm512_maskz_loadu_ps(mask, a), epsilon_vec); b_vec = _mm512_add_ps(_mm512_maskz_loadu_ps(mask, b), epsilon_vec); n = 0; - } else { + } + else { a_vec = _mm512_add_ps(_mm512_loadu_ps(a), epsilon_vec); b_vec = _mm512_add_ps(_mm512_loadu_ps(b), epsilon_vec); a += 16, b += 16, n -= 16; @@ -455,15 +456,14 @@ SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t __m512 log_ratio_vec = _simsimd_log2_f32_skylake(ratio_vec); __m512 
prod_vec = _mm512_mul_ps(a_vec, log_ratio_vec); sum_vec = _mm512_add_ps(sum_vec, prod_vec); - if (n) - goto simsimd_kl_f32_skylake_cycle; + if (n) goto simsimd_kl_f32_skylake_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; *result = _mm512_reduce_add_ps(sum_vec) * log2_normalizer; } -SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 sum_a_vec = _mm512_setzero(); __m512 sum_b_vec = _mm512_setzero(); simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; @@ -476,7 +476,8 @@ SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; @@ -492,8 +493,7 @@ SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t __m512 log_ratio_b_vec = _simsimd_log2_f32_skylake(ratio_b_vec); sum_a_vec = _mm512_maskz_fmadd_ps(nonzero_mask, a_vec, log_ratio_a_vec, sum_a_vec); sum_b_vec = _mm512_maskz_fmadd_ps(nonzero_mask, b_vec, log_ratio_b_vec, sum_b_vec); - if (n) - goto simsimd_js_f32_skylake_cycle; + if (n) goto simsimd_js_f32_skylake_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; *result = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_vec, sum_b_vec)) * log2_normalizer / 2; @@ -525,8 +525,8 @@ SIMSIMD_INTERNAL __m512h _simsimd_log2_f16_sapphire(__m512h x) { return _mm512_add_ph(_mm512_mul_ph(p, _mm512_sub_ph(m, one)), e); } -SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512h sum_vec = _mm512_setzero_ph(); __m512h epsilon_vec = _mm512_set1_ph((simsimd_f16_t)SIMSIMD_F16_DIVISION_EPSILON); __m512h a_vec, b_vec; @@ -537,7 +537,8 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ a_vec = _mm512_maskz_add_ph(mask, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)), epsilon_vec); b_vec = _mm512_maskz_add_ph(mask, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)), epsilon_vec); n = 0; - } else { + } + else { a_vec = _mm512_add_ph(_mm512_castsi512_ph(_mm512_loadu_epi16(a)), epsilon_vec); b_vec = _mm512_add_ph(_mm512_castsi512_ph(_mm512_loadu_epi16(b)), epsilon_vec); a += 32, b += 32, n -= 32; @@ -546,15 +547,14 @@ SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ __m512h log_ratio_vec = _simsimd_log2_f16_sapphire(ratio_vec); __m512h prod_vec = _mm512_mul_ph(a_vec, log_ratio_vec); sum_vec = _mm512_add_ph(sum_vec, prod_vec); - if (n) - goto simsimd_kl_f16_sapphire_cycle; + if (n) goto simsimd_kl_f16_sapphire_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; *result = _mm512_reduce_add_ph(sum_vec) * log2_normalizer; } -SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512h sum_a_vec = _mm512_setzero_ph(); __m512h sum_b_vec = _mm512_setzero_ph(); __m512h epsilon_vec = 
_mm512_set1_ph((simsimd_f16_t)SIMSIMD_F16_DIVISION_EPSILON); @@ -566,7 +566,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ a_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); b_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); n = 0; - } else { + } + else { a_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(a)); b_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(b)); a += 32, b += 32, n -= 32; @@ -582,8 +583,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ __m512h log_ratio_b_vec = _simsimd_log2_f16_sapphire(ratio_b_vec); sum_a_vec = _mm512_maskz_fmadd_ph(nonzero_mask, a_vec, log_ratio_a_vec, sum_a_vec); sum_b_vec = _mm512_maskz_fmadd_ph(nonzero_mask, b_vec, log_ratio_b_vec, sum_b_vec); - if (n) - goto simsimd_js_f16_sapphire_cycle; + if (n) goto simsimd_js_f16_sapphire_cycle; simsimd_f32_t log2_normalizer = 0.693147181f; *result = _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * log2_normalizer / 2; @@ -592,7 +592,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_ #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SAPPHIRE -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 #ifdef __cplusplus } diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index 0cfab1a8..e7bee6fa 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -105,13 +105,14 @@ #include "binary.h" // Hamming, Jaccard #include "curved.h" // Mahalanobis, Bilinear Forms #include "dot.h" // Inner (dot) product, and its conjugate +#include "fma.h" // Weighted Sum, Fused Multiply-Add #include "geospatial.h" // Haversine and Vincenty #include "probability.h" // Kullback-Leibler, Jensen–Shannon #include "sparse.h" // Intersect #include "spatial.h" // L2, Cosine // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. -#if defined(SIMSIMD_DEFINED_APPLE) +#if defined(_SIMSIMD_DEFINED_APPLE) #include <sys/sysctl.h> #endif @@ -164,6 +165,10 @@ typedef enum { simsimd_metric_js_k = 's', ///< Jensen-Shannon divergence simsimd_metric_jensen_shannon_k = 's', ///< Jensen-Shannon divergence alias + // BLAS-like operations: + simsimd_metric_fma_k = 'f', ///< Fused Multiply-Add + simsimd_metric_wsum_k = 'w', ///< Weighted Sum + } simsimd_metric_kind_t; /** * @param[out] d Output value as a double-precision float. * In complex dot-products @b two scalars are exported for the real and imaginary parts. */ -typedef void (*simsimd_metric_dense_punned_t)(void const* a, void const* b, simsimd_size_t n, simsimd_distance_t* d); +typedef void (*simsimd_metric_dense_punned_t)(void const *a, void const *b, simsimd_size_t n, simsimd_distance_t *d); /** * @brief Type-punned function pointer for sparse vector representations and similarity measures. * @param[in] a_length Number of scalar words in the first input array. * @param[in] b_length Number of scalar words in the second input array. * @param[out] d Output value as a double-precision float, generally without decimals. */ -typedef void (*simsimd_metric_sparse_punned_t)(void const* a, void const* b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* d); +typedef void (*simsimd_metric_sparse_punned_t)(void const *a, void const *b, // + simsimd_size_t a_length, simsimd_size_t b_length, // + simsimd_distance_t *d); /** * @brief Type-punned function pointer for curved vector spaces and similarity measures.
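For context, a minimal usage sketch for the two kernel types introduced in the next hunk; this is not part of the patch itself, and the `scale_and_shift` wrapper and buffer names are illustrative:

```c
#include <simsimd/simsimd.h>

// Resolve the best available f32 FMA kernel at runtime, then compute
// y[i] = 2 * a[i] * b[i] + 0.5 * c[i] over n elements.
void scale_and_shift(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c,
                     simsimd_size_t n, simsimd_f32_t *y) {
    simsimd_metric_punned_t kernel = 0;
    simsimd_capability_t used = 0;
    simsimd_capability_t supported = simsimd_capabilities();
    simsimd_find_metric_punned(simsimd_metric_fma_k, simsimd_datatype_f32_k, supported, supported, &kernel, &used);
    if (!kernel) return; // no implementation compiled in for this kind and datatype
    ((simsimd_kernel_fma_punned_t)kernel)(a, b, c, n, /*alpha=*/2.0, /*beta=*/0.5, y);
}
```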
@@ -263,8 +268,39 @@ typedef void (*simsimd_metric_sparse_punned_t)(void const* a, void const* b, * @param[in] n Number of scalar words in the input arrays. * @param[out] d Output value as a double-precision float. */ -typedef void (*simsimd_metric_curved_punned_t)(void const* a, void const* b, void const* c, // - simsimd_size_t n, simsimd_distance_t* d); +typedef void (*simsimd_metric_curved_punned_t)(void const *a, void const *b, void const *c, // + simsimd_size_t n, simsimd_distance_t *d); + +/** + * @brief Type-punned function pointer for FMA operations on dense vector representations. + * Implements the element-wise `y[i] = alpha * a[i] * b[i] + beta * c[i]` operation. + * + * @param[in] a Pointer to the first data array. + * @param[in] b Pointer to the second data array. + * @param[in] c Pointer to the third data array. + * @param[in] n Number of scalar words in the input arrays. + * @param[in] alpha Scaling factor applied to the element-wise product of the first two arrays. + * @param[in] beta Scaling factor for the third array. + * @param[out] y Output array in the same precision as the input arrays. + */ +typedef void (*simsimd_kernel_fma_punned_t)(void const *a, void const *b, void const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + void *y); + +/** + * @brief Type-punned function pointer for Weighted Sum operations on dense vector representations. + * Implements the element-wise `y[i] = alpha * a[i] + beta * b[i]` operation. + * + * @param[in] a Pointer to the first data array. + * @param[in] b Pointer to the second data array. + * @param[in] n Number of scalar words in the input arrays. + * @param[in] alpha Scaling factor for the first array. + * @param[in] beta Scaling factor for the second array. + * @param[out] y Output array in the same precision as the input arrays. + */ +typedef void (*simsimd_kernel_wsum_punned_t)(void const *a, void const *b, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + void *y); /** * @brief Type-punned function pointer for a SimSIMD public interface. @@ -280,8 +316,8 @@ SIMSIMD_DYNAMIC void simsimd_find_metric_punned( // simsimd_datatype_t datatype, // simsimd_capability_t supported, // simsimd_capability_t allowed, // - simsimd_metric_punned_t* metric_output, // - simsimd_capability_t* capability_output); + simsimd_metric_punned_t *metric_output, // + simsimd_capability_t *capability_output); #else SIMSIMD_PUBLIC simsimd_capability_t simsimd_capabilities(void); SIMSIMD_PUBLIC void simsimd_find_metric_punned( // @@ -289,11 +325,11 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( // simsimd_datatype_t datatype, // simsimd_capability_t supported, // simsimd_capability_t allowed, // - simsimd_metric_punned_t* metric_output, // - simsimd_capability_t* capability_output); + simsimd_metric_punned_t *metric_output, // + simsimd_capability_t *capability_output); #endif -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 /** * @brief Function to determine the SIMD capabilities of the current 64-bit x86 machine at @b runtime. @@ -378,9 +414,9 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_x86(void) { (simsimd_cap_serial_k)); } -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM /* When compiling the next section, one may get: selected processor does not support system register name 'id_aa64zfr0_el1'.
* Suppressing assembler errors is very complicated, so when dealing with older ARM CPUs it's simpler to compile this @@ -395,18 +431,14 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_x86(void) { * @return A bitmask of the SIMD capabilities represented as a `simsimd_capability_t` enum value. */ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_arm(void) { -#if defined(SIMSIMD_DEFINED_APPLE) +#if defined(_SIMSIMD_DEFINED_APPLE) // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. uint32_t supports_neon = 0, supports_fp16 = 0, supports_bf16 = 0, supports_i8mm = 0; size_t size = sizeof(supports_neon); - if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) - supports_neon = 0; - if (sysctlbyname("hw.optional.arm.FEAT_FP16", &supports_fp16, &size, NULL, 0) != 0) - supports_fp16 = 0; - if (sysctlbyname("hw.optional.arm.FEAT_BF16", &supports_bf16, &size, NULL, 0) != 0) - supports_bf16 = 0; - if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &supports_i8mm, &size, NULL, 0) != 0) - supports_i8mm = 0; + if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) supports_neon = 0; + if (sysctlbyname("hw.optional.arm.FEAT_FP16", &supports_fp16, &size, NULL, 0) != 0) supports_fp16 = 0; + if (sysctlbyname("hw.optional.arm.FEAT_BF16", &supports_bf16, &size, NULL, 0) != 0) supports_bf16 = 0; + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &supports_i8mm, &size, NULL, 0) != 0) supports_i8mm = 0; return (simsimd_capability_t)( // (simsimd_cap_neon_k * (supports_neon)) | // @@ -415,7 +447,7 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_arm(void) { (simsimd_cap_neon_i8_k * (supports_neon && supports_i8mm)) | // (simsimd_cap_serial_k)); -#elif defined(SIMSIMD_DEFINED_LINUX) +#elif defined(_SIMSIMD_DEFINED_LINUX) // Read CPUID registers directly unsigned long id_aa64isar0_el1 = 0, id_aa64isar1_el1 = 0, id_aa64pfr0_el1 = 0, id_aa64zfr0_el1 = 0; @@ -445,8 +477,7 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_arm(void) { // Now let's unpack the status flags from ID_AA64ZFR0_EL1 // https://developer.arm.com/documentation/ddi0601/2024-03/AArch64-Registers/ID-AA64ZFR0-EL1--SVE-Feature-ID-Register-0?lang=en - if (supports_sve) - __asm__ __volatile__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1)); + if (supports_sve) __asm__ __volatile__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1)); // I8MM, bits [47:44] of ID_AA64ZFR0_EL1 unsigned supports_sve_i8mm = ((id_aa64zfr0_el1 >> 44) & 0xF) >= 1; // BF16, bits [23:20] of ID_AA64ZFR0_EL1 @@ -472,7 +503,7 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_arm(void) { (simsimd_cap_sve2_k * (supports_sve2)) | // (simsimd_cap_sve2p1_k * (supports_sve2p1)) | // (simsimd_cap_serial_k)); -#else // SIMSIMD_DEFINED_LINUX +#else // if !_SIMSIMD_DEFINED_LINUX return simsimd_cap_serial_k; #endif } @@ -487,12 +518,12 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_arm(void) { * @return A bitmask of the SIMD capabilities represented as a `simsimd_capability_t` enum value. 
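 *
 * On ARM Linux, each feature probe above follows the same pattern: shift the feature register
 * down to the field's least-significant bit and mask four bits, e.g. I8MM support is read as
 * `((id_aa64zfr0_el1 >> 44) & 0xF) >= 1`, i.e. bits [47:44] of ID_AA64ZFR0_EL1.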
*/ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_implementation(void) { -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 return _simsimd_capabilities_x86(); -#endif // SIMSIMD_TARGET_X86 -#if SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_ARM return _simsimd_capabilities_arm(); -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM return simsimd_cap_serial_k; } @@ -506,693 +537,667 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_implementation(void) { #pragma clang diagnostic ignored "-Wvolatile" #endif -/** - * @brief Determines the best suited metric implementation based on the given datatype, - * supported and allowed by hardware capabilities. - * - * @param kind The kind of metric to be evaluated. - * @param datatype The data type for which the metric needs to be evaluated. - * @param supported The hardware capabilities supported by the CPU. - * @param allowed The hardware capabilities allowed for use. - * @param metric_output Output variable for the selected similarity function. - * @param capability_output Output variable for the utilized hardware capabilities. - */ -SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( // - simsimd_metric_kind_t kind, // - simsimd_datatype_t datatype, // - simsimd_capability_t supported, // - simsimd_capability_t allowed, // - simsimd_metric_punned_t* metric_output, // - simsimd_capability_t* capability_output) { - - // Modern compilers abso-freaking-lutely love optimizing-out my logic! - // Just marking the variables as `volatile` is not enough, so we have - // to add inline assembly to further discourage them! -#if defined(_MSC_VER) - _ReadWriteBarrier(); -#else - __asm__ __volatile__("" ::: "memory"); -#endif - - volatile simsimd_metric_punned_t* m = metric_output; - volatile simsimd_capability_t* c = capability_output; - volatile simsimd_capability_t viable = (simsimd_capability_t)(supported & allowed); - *m = (simsimd_metric_punned_t)0; - *c = (simsimd_capability_t)0; - +SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f64(simsimd_capability_t v, simsimd_metric_kind_t k, + simsimd_metric_punned_t *m, simsimd_capability_t *c) { typedef simsimd_metric_punned_t m_t; - switch (datatype) { - - case simsimd_datatype_unknown_k: break; - - // These data-types are not supported yet - case simsimd_datatype_i4x2_k: break; - case simsimd_datatype_i16_k: break; - case simsimd_datatype_i32_k: break; - case simsimd_datatype_i64_k: break; - case simsimd_datatype_u64_k: break; - - // Double-precision floating-point vectors - case simsimd_datatype_f64_k: { #if SIMSIMD_TARGET_SVE - if (viable & simsimd_cap_sve_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_sve, *c = simsimd_cap_sve_k; return; - default: break; - } + if (v & simsimd_cap_sve_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_sve, *c = simsimd_cap_sve_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_NEON - if (viable & 
simsimd_cap_neon_k) - switch (kind) { - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_neon, *c = simsimd_cap_neon_k; return; - default: break; - } + if (v & simsimd_cap_neon_k) switch (k) { + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_neon, *c = simsimd_cap_neon_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_SKYLAKE - if (viable & simsimd_cap_skylake_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_skylake, *c = simsimd_cap_skylake_k; return; - default: break; - } -#endif - if (viable & simsimd_cap_serial_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f64_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_mahalanobis_k: - *m = (m_t)&simsimd_mahalanobis_f64_serial, *c = simsimd_cap_serial_k; - return; - default: break; - } - - break; - } - - // Single-precision floating-point vectors - case simsimd_datatype_f32_k: { + if (v & simsimd_cap_skylake_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f64_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (v & simsimd_cap_haswell_k) switch (k) { + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f64_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f64_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + if (v & simsimd_cap_serial_k) switch (k) { + 
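            // The serial kernels terminate every dispatch chain: each `_simsimd_find_metric_punned_*`
            // helper probes the widest viable ISA first (SVE, then NEON on ARM; Skylake, then Haswell
            // on x86) and falls back here when no SIMD branch matched the requested kind under the
            // viable `supported & allowed` capability mask.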
case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f64_serial, *c = simsimd_cap_serial_k; return; + default: break; + } +} +SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f32(simsimd_capability_t v, simsimd_metric_kind_t k, + simsimd_metric_punned_t *m, simsimd_capability_t *c) { + typedef simsimd_metric_punned_t m_t; #if SIMSIMD_TARGET_SVE - if (viable & simsimd_cap_sve_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_sve, *c = simsimd_cap_sve_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_sve, *c = simsimd_cap_sve_k; return; - default: break; - } + if (v & simsimd_cap_sve_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_sve, *c = simsimd_cap_sve_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_NEON - if (viable & simsimd_cap_neon_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_neon, *c = simsimd_cap_neon_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_neon, *c = simsimd_cap_neon_k; return; - default: break; - } + if (v & simsimd_cap_neon_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_neon, *c = simsimd_cap_neon_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_SKYLAKE - if (viable & simsimd_cap_skylake_k) - switch (kind) { - case 
simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32_skylake, *c = simsimd_cap_skylake_k; return; - case simsimd_metric_mahalanobis_k: - *m = (m_t)&simsimd_mahalanobis_f32_skylake, *c = simsimd_cap_skylake_k; - return; - default: break; - } + if (v & simsimd_cap_skylake_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_mahalanobis_k: + *m = (m_t)&simsimd_mahalanobis_f32_skylake, *c = simsimd_cap_skylake_k; + return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f32_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_HASWELL - if (viable & simsimd_cap_haswell_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_haswell, *c = simsimd_cap_haswell_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_haswell, *c = simsimd_cap_haswell_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_haswell, *c = simsimd_cap_haswell_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_haswell, *c = simsimd_cap_haswell_k; return; - default: break; - } -#endif - - if (viable & simsimd_cap_serial_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32_serial, *c = simsimd_cap_serial_k; return; - case simsimd_metric_mahalanobis_k: - *m = (m_t)&simsimd_mahalanobis_f32_serial, *c = simsimd_cap_serial_k; - return; - default: break; - } - - break; - } - // Half-precision floating-point vectors - case simsimd_datatype_f16_k: { + if (v & simsimd_cap_haswell_k) switch (k) { + case simsimd_metric_dot_k: *m = 
(m_t)&simsimd_dot_f32_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f32_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f32_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + if (v & simsimd_cap_serial_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f32_serial, *c = simsimd_cap_serial_k; return; + default: break; + } +} +SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f16(simsimd_capability_t v, simsimd_metric_kind_t k, + simsimd_metric_punned_t *m, simsimd_capability_t *c) { + typedef simsimd_metric_punned_t m_t; #if SIMSIMD_TARGET_SVE_F16 - if (viable & simsimd_cap_sve_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sve, *c = simsimd_cap_sve_f16_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sve, *c = simsimd_cap_sve_f16_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sve, *c = simsimd_cap_sve_f16_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_sve, *c = simsimd_cap_sve_f16_k; return; - default: break; - } + if (v & simsimd_cap_sve_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sve, *c = simsimd_cap_sve_f16_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sve, *c = simsimd_cap_sve_f16_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sve, *c = simsimd_cap_sve_f16_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_sve, *c = simsimd_cap_sve_f16_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_NEON_F16 - if (viable & simsimd_cap_neon_f16_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_bilinear_k: *m = 
(m_t)&simsimd_bilinear_f16_neon, *c = simsimd_cap_neon_f16_k; return; - case simsimd_metric_mahalanobis_k: - *m = (m_t)&simsimd_mahalanobis_f16_neon, *c = simsimd_cap_neon_f16_k; - return; - default: break; - } + if (v & simsimd_cap_neon_f16_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_f16_neon, *c = simsimd_cap_neon_f16_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_SAPPHIRE - if (viable & simsimd_cap_sapphire_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_sapphire, *c = simsimd_cap_sapphire_k; return; - case simsimd_metric_bilinear_k: - *m = (m_t)&simsimd_bilinear_f16_sapphire, *c = simsimd_cap_sapphire_k; - return; - case simsimd_metric_mahalanobis_k: - *m = (m_t)&simsimd_mahalanobis_f16_sapphire, *c = simsimd_cap_sapphire_k; - return; - default: break; - } + if (v & simsimd_cap_sapphire_k) switch (k) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_mahalanobis_k: + *m = (m_t)&simsimd_mahalanobis_f16_sapphire, *c = simsimd_cap_sapphire_k; + return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + default: break; + } #endif #if SIMSIMD_TARGET_HASWELL - if (viable & simsimd_cap_haswell_k) - switch (kind) { - case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_haswell, *c = simsimd_cap_haswell_k; return; - case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_haswell, *c = simsimd_cap_haswell_k; return; - case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_haswell, *c = 
simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_mahalanobis_k:
-            *m = (m_t)&simsimd_mahalanobis_f16_haswell, *c = simsimd_cap_haswell_k;
-            return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_mahalanobis_k:
-            *m = (m_t)&simsimd_mahalanobis_f16_serial, *c = simsimd_cap_serial_k;
-            return;
-        default: break;
-        }
-
-    break;
-    }
-    // Brain floating-point vectors
-    case simsimd_datatype_bf16_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_mahalanobis_k:
+            *m = (m_t)&simsimd_mahalanobis_f16_haswell, *c = simsimd_cap_haswell_k;
+            return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f16_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f16_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_bf16(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                       simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE_BF16
-    if (viable & simsimd_cap_sve_bf16_k)
-        switch (kind) {
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve_bf16_k) switch (k) {
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_sve, *c = simsimd_cap_sve_bf16_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_NEON_BF16
-    if (viable & simsimd_cap_neon_bf16_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_bf16_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_neon, *c = simsimd_cap_neon_bf16_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_GENOA
-    if (viable & simsimd_cap_genoa_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_mahalanobis_k:
-            *m = (m_t)&simsimd_mahalanobis_bf16_genoa, *c = simsimd_cap_genoa_k;
-            return;
-        default: break;
-        }
+    if (v & simsimd_cap_genoa_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_bf16_genoa, *c = simsimd_cap_genoa_k; return;
+        default: break;
+        }
+#endif
+#if SIMSIMD_TARGET_SKYLAKE
+    if (v & simsimd_cap_skylake_k) switch (k) {
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_bf16_skylake, *c = simsimd_cap_skylake_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_bf16_skylake, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_bilinear_k:
-            *m = (m_t)&simsimd_bilinear_bf16_haswell, *c = simsimd_cap_haswell_k;
-            return;
-        case simsimd_metric_mahalanobis_k:
-            *m = (m_t)&simsimd_mahalanobis_bf16_haswell, *c = simsimd_cap_haswell_k;
-            return;
-        default: break;
-        }
-#endif
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_mahalanobis_k:
-            *m = (m_t)&simsimd_mahalanobis_bf16_serial, *c = simsimd_cap_serial_k;
-            return;
-        default: break;
-        }
-
-    break;
-    }
-    // Single-byte signed integer vectors
-    case simsimd_datatype_i8_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_mahalanobis_k:
+            *m = (m_t)&simsimd_mahalanobis_bf16_haswell, *c = simsimd_cap_haswell_k;
+            return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_bf16_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_js_k: *m = (m_t)&simsimd_js_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_mahalanobis_k:
+            *m = (m_t)&simsimd_mahalanobis_bf16_serial, *c = simsimd_cap_serial_k;
+            return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_bf16_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_bf16_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_i8(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                     simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_NEON_I8
-    if (viable & simsimd_cap_neon_i8_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_neon, *c = simsimd_cap_neon_i8_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_i8_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_neon, *c = simsimd_cap_neon_i8_k; return;
+        default: break;
+        }
+#endif
+#if SIMSIMD_TARGET_SAPPHIRE
+    if (v & simsimd_cap_sapphire_k) switch (k) {
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_i8_sapphire, *c = simsimd_cap_sapphire_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_i8_sapphire, *c = simsimd_cap_sapphire_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_ICE
-    if (viable & simsimd_cap_ice_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_ice, *c = simsimd_cap_ice_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_ice_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_ice, *c = simsimd_cap_ice_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_haswell, *c = simsimd_cap_haswell_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Single-byte unsigned integer vectors
-    case simsimd_datatype_u8_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_i8_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_i8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_i8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_i8_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_u8(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                     simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_NEON_I8
-    if (viable & simsimd_cap_neon_i8_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_neon, *c = simsimd_cap_neon_i8_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_neon, *c = simsimd_cap_neon_i8_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_i8_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_neon, *c = simsimd_cap_neon_i8_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_neon, *c = simsimd_cap_neon_i8_k; return;
+        default: break;
+        }
+#endif
+#if SIMSIMD_TARGET_SAPPHIRE
+    if (v & simsimd_cap_sapphire_k) switch (k) {
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_u8_sapphire, *c = simsimd_cap_sapphire_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_u8_sapphire, *c = simsimd_cap_sapphire_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_ICE
-    if (viable & simsimd_cap_ice_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_ice, *c = simsimd_cap_ice_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_ice_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_ice, *c = simsimd_cap_ice_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_haswell, *c = simsimd_cap_haswell_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Binary vectors
-    case simsimd_datatype_b8_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_u8_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_u8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_u8_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_b8(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                     simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE
-    if (viable & simsimd_cap_sve_k)
-        switch (kind) {
-        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_sve, *c = simsimd_cap_sve_k; return;
-        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_sve, *c = simsimd_cap_sve_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve_k) switch (k) {
+        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_sve, *c = simsimd_cap_sve_k; return;
+        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_sve, *c = simsimd_cap_sve_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_NEON
-    if (viable & simsimd_cap_neon_k)
-        switch (kind) {
-        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_neon, *c = simsimd_cap_neon_k; return;
-        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_neon, *c = simsimd_cap_neon_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_k) switch (k) {
+        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_neon, *c = simsimd_cap_neon_k; return;
+        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_neon, *c = simsimd_cap_neon_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_ICE
-    if (viable & simsimd_cap_ice_k)
-        switch (kind) {
-        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_ice, *c = simsimd_cap_ice_k; return;
-        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_ice, *c = simsimd_cap_ice_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_ice_k) switch (k) {
+        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_ice, *c = simsimd_cap_ice_k; return;
+        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_ice, *c = simsimd_cap_ice_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_haswell, *c = simsimd_cap_haswell_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Complex floating-point vectors
-    case simsimd_datatype_f32c_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f64c(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                       simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE
-    if (viable & simsimd_cap_sve_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_sve, *c = simsimd_cap_sve_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_sve, *c = simsimd_cap_sve_k; return;
-        default: break;
-        }
-#endif
-#if SIMSIMD_TARGET_NEON
-    if (viable & simsimd_cap_neon_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_neon, *c = simsimd_cap_neon_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_neon, *c = simsimd_cap_neon_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_sve, *c = simsimd_cap_sve_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_sve, *c = simsimd_cap_sve_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_SKYLAKE
-    if (viable & simsimd_cap_skylake_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_skylake, *c = simsimd_cap_skylake_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_skylake, *c = simsimd_cap_skylake_k; return;
-        default: break;
-        }
-#endif
-#if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_haswell, *c = simsimd_cap_haswell_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Complex double-precision floating-point vectors
-    case simsimd_datatype_f64c_k: {
+    if (v & simsimd_cap_skylake_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_skylake, *c = simsimd_cap_skylake_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_skylake, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f32c(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                       simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE
-    if (viable & simsimd_cap_sve_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_sve, *c = simsimd_cap_sve_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_sve, *c = simsimd_cap_sve_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_sve, *c = simsimd_cap_sve_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_sve, *c = simsimd_cap_sve_k; return;
+        default: break;
+        }
+#endif
+#if SIMSIMD_TARGET_NEON
+    if (v & simsimd_cap_neon_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_neon, *c = simsimd_cap_neon_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_neon, *c = simsimd_cap_neon_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_SKYLAKE
-    if (viable & simsimd_cap_skylake_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_skylake, *c = simsimd_cap_skylake_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_skylake, *c = simsimd_cap_skylake_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Complex half-precision floating-point vectors
-    case simsimd_datatype_f16c_k: {
+    if (v & simsimd_cap_skylake_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_skylake, *c = simsimd_cap_skylake_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_skylake, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
+#endif
+#if SIMSIMD_TARGET_HASWELL
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f16c(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                       simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE_F16
-    if (viable & simsimd_cap_sve_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sve, *c = simsimd_cap_sve_f16_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sve, *c = simsimd_cap_sve_f16_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sve, *c = simsimd_cap_sve_f16_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sve, *c = simsimd_cap_sve_f16_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_NEON_F16
-    if (viable & simsimd_cap_neon_f16_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_neon, *c = simsimd_cap_neon_f16_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_neon, *c = simsimd_cap_neon_f16_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_f16_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_neon, *c = simsimd_cap_neon_f16_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_neon, *c = simsimd_cap_neon_f16_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_SAPPHIRE
-    if (viable & simsimd_cap_sapphire_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sapphire_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_HASWELL
-    if (viable & simsimd_cap_haswell_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_haswell, *c = simsimd_cap_haswell_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_haswell, *c = simsimd_cap_haswell_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Complex Brain floating-point vectors
-    case simsimd_datatype_bf16c_k: {
+    if (v & simsimd_cap_haswell_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_haswell, *c = simsimd_cap_haswell_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_haswell, *c = simsimd_cap_haswell_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_bf16c(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                        simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_NEON_BF16
-    if (viable & simsimd_cap_neon_bf16_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_neon, *c = simsimd_cap_neon_bf16_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_neon, *c = simsimd_cap_neon_bf16_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_bf16_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_neon, *c = simsimd_cap_neon_bf16_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_neon, *c = simsimd_cap_neon_bf16_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_GENOA
-    if (viable & simsimd_cap_genoa_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_genoa, *c = simsimd_cap_genoa_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_genoa, *c = simsimd_cap_genoa_k; return;
-        default: break;
-        }
-#endif
-
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_serial, *c = simsimd_cap_serial_k; return;
-        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-
-    // Unsigned 16-bit integer vectors
-    case simsimd_datatype_u16_k: {
+    if (v & simsimd_cap_genoa_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_genoa, *c = simsimd_cap_genoa_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_genoa, *c = simsimd_cap_genoa_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_serial, *c = simsimd_cap_serial_k; return;
+        case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_u16(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                      simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE2
-    if (viable & simsimd_cap_sve2_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_sve2, *c = simsimd_cap_sve2_k; return;
-        case simsimd_metric_spdot_counts_k:
-            *m = (m_t)&simsimd_spdot_counts_u16_sve2, *c = simsimd_cap_sve2_k;
-            return;
-        case simsimd_metric_spdot_weights_k:
-            *m = (m_t)&simsimd_spdot_weights_u16_sve2, *c = simsimd_cap_sve2_k;
-            return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve2_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_sve2, *c = simsimd_cap_sve2_k; return;
+        case simsimd_metric_spdot_counts_k: *m = (m_t)&simsimd_spdot_counts_u16_sve2, *c = simsimd_cap_sve2_k; return;
+        case simsimd_metric_spdot_weights_k: *m = (m_t)&simsimd_spdot_weights_u16_sve2, *c = simsimd_cap_sve2_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_NEON
-    if (viable & simsimd_cap_neon_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_neon, *c = simsimd_cap_neon_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_neon, *c = simsimd_cap_neon_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_TURIN
-    if (viable & simsimd_cap_turin_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_turin, *c = simsimd_cap_turin_k; return;
-        case simsimd_metric_spdot_counts_k:
-            *m = (m_t)&simsimd_spdot_counts_u16_turin, *c = simsimd_cap_turin_k;
-            return;
-        case simsimd_metric_spdot_weights_k:
-            *m = (m_t)&simsimd_spdot_weights_u16_turin, *c = simsimd_cap_turin_k;
-            return;
-        default: break;
-        }
+    if (v & simsimd_cap_turin_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_turin, *c = simsimd_cap_turin_k; return;
+        case simsimd_metric_spdot_counts_k: *m = (m_t)&simsimd_spdot_counts_u16_turin, *c = simsimd_cap_turin_k; return;
+        case simsimd_metric_spdot_weights_k:
+            *m = (m_t)&simsimd_spdot_weights_u16_turin, *c = simsimd_cap_turin_k;
+            return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_ICE
-    if (viable & simsimd_cap_ice_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_ice, *c = simsimd_cap_skylake_k; return;
-        default: break;
-        }
-#endif
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
-    // Unsigned 32-bit integer vectors
-    case simsimd_datatype_u32_k: {
+    if (v & simsimd_cap_ice_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_ice, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_u32(simsimd_capability_t v, simsimd_metric_kind_t k,
+                                                      simsimd_metric_punned_t *m, simsimd_capability_t *c) {
+    typedef simsimd_metric_punned_t m_t;
 #if SIMSIMD_TARGET_SVE2
-    if (viable & simsimd_cap_sve2_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_sve2, *c = simsimd_cap_sve2_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_sve2_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_sve2, *c = simsimd_cap_sve2_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_NEON
-    if (viable & simsimd_cap_neon_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_neon, *c = simsimd_cap_neon_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_neon_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_neon, *c = simsimd_cap_neon_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_TURIN
-    if (viable & simsimd_cap_turin_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_turin, *c = simsimd_cap_skylake_k; return;
-        default: break;
-        }
+    if (v & simsimd_cap_turin_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_turin, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
 #endif
 #if SIMSIMD_TARGET_ICE
-    if (viable & simsimd_cap_ice_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_ice, *c = simsimd_cap_skylake_k; return;
-        default: break;
-        }
-#endif
-    if (viable & simsimd_cap_serial_k)
-        switch (kind) {
-        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_serial, *c = simsimd_cap_serial_k; return;
-        default: break;
-        }
-
-    break;
-    }
+    if (v & simsimd_cap_ice_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_ice, *c = simsimd_cap_skylake_k; return;
+        default: break;
+        }
+#endif
+    if (v & simsimd_cap_serial_k) switch (k) {
+        case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u32_serial, *c = simsimd_cap_serial_k; return;
+        default: break;
+        }
+}
+
+/**
+ *  @brief Determines the best suited metric implementation based on the given datatype,
+ *         supported and allowed by hardware capabilities.
+ *
+ *  @param kind The kind of metric to be evaluated.
+ *  @param datatype The data type for which the metric needs to be evaluated.
+ *  @param supported The hardware capabilities supported by the CPU.
+ *  @param allowed The hardware capabilities allowed for use.
+ *  @param metric_output Output variable for the selected similarity function.
+ *  @param capability_output Output variable for the utilized hardware capabilities.
+ */
+SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( //
+    simsimd_metric_kind_t kind,                                   //
+    simsimd_datatype_t datatype,                                  //
+    simsimd_capability_t supported,                               //
+    simsimd_capability_t allowed,                                 //
+    simsimd_metric_punned_t *metric_output,                       //
+    simsimd_capability_t *capability_output) {
+
+    // Modern compilers abso-freaking-lutely love optimizing-out my logic!
+    // Just marking the variables as `volatile` is not enough, so we have
+    // to add inline assembly to further discourage them!
+#if defined(_MSC_VER)
+    _ReadWriteBarrier();
+#else
+    __asm__ __volatile__("" ::: "memory");
+#endif
+
+    volatile simsimd_metric_punned_t *m = metric_output;
+    volatile simsimd_capability_t *c = capability_output;
+    volatile simsimd_capability_t viable = (simsimd_capability_t)(supported & allowed);
+
+    switch (datatype) {
+
+    case simsimd_datatype_f64_k: _simsimd_find_metric_punned_f64(viable, kind, m, c); return;
+    case simsimd_datatype_f32_k: _simsimd_find_metric_punned_f32(viable, kind, m, c); return;
+    case simsimd_datatype_f16_k: _simsimd_find_metric_punned_f16(viable, kind, m, c); return;
+    case simsimd_datatype_bf16_k: _simsimd_find_metric_punned_bf16(viable, kind, m, c); return;
+    case simsimd_datatype_i8_k: _simsimd_find_metric_punned_i8(viable, kind, m, c); return;
+    case simsimd_datatype_u8_k: _simsimd_find_metric_punned_u8(viable, kind, m, c); return;
+    case simsimd_datatype_b8_k: _simsimd_find_metric_punned_b8(viable, kind, m, c); return;
+    case simsimd_datatype_f32c_k: _simsimd_find_metric_punned_f32c(viable, kind, m, c); return;
+    case simsimd_datatype_f64c_k: _simsimd_find_metric_punned_f64c(viable, kind, m, c); return;
+    case simsimd_datatype_f16c_k: _simsimd_find_metric_punned_f16c(viable, kind, m, c); return;
+    case simsimd_datatype_bf16c_k: _simsimd_find_metric_punned_bf16c(viable, kind, m, c); return;
+    case simsimd_datatype_u16_k: _simsimd_find_metric_punned_u16(viable, kind, m, c); return;
+    case simsimd_datatype_u32_k: _simsimd_find_metric_punned_u32(viable, kind, m, c); return;
+
+    // These data-types are not supported yet
+    case simsimd_datatype_i4x2_k: break;
+    case simsimd_datatype_i16_k: break;
+    case simsimd_datatype_i32_k: break;
+    case simsimd_datatype_i64_k: break;
+    case simsimd_datatype_u64_k: break;
+    case simsimd_datatype_unknown_k: break;
+    default: break;
    }
+    // Replace with zeros if no suitable implementation was found
+    *m = (simsimd_metric_punned_t)0;
+    *c = (simsimd_capability_t)0;
+
    // Modern compilers abso-freaking-lutely love optimizing-out my logic!
    // Just marking the variables as `volatile` is not enough, so we have
    // to add inline assembly to further discourage them!
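The refactor above fixes the dispatcher's contract: callers pass a metric kind, a datatype, the capabilities the CPU reports, and a mask of capabilities they allow, and receive a type-punned function pointer plus the capability that was actually chosen. A minimal consumption sketch, not part of the patch, assuming the dense-kernel pointer shape used throughout this header:

```c
#include <simsimd/simsimd.h>

// Sketch only: resolve the best cosine kernel for f32 once, then reuse it.
// The function-pointer shape below mirrors this header's dense kernels and
// is an assumption for illustration, not a documented guarantee.
typedef void (*dense_f32_kernel_t)(simsimd_f32_t const *, simsimd_f32_t const *, //
                                   simsimd_size_t, simsimd_distance_t *);

static simsimd_distance_t cos_f32_via_dispatch(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n) {
    simsimd_metric_punned_t metric = 0;
    simsimd_capability_t used = simsimd_cap_serial_k;
    simsimd_capability_t supported = simsimd_capabilities();
    // Passing `supported` as the `allowed` mask imposes no extra restriction;
    // narrowing it (e.g. to the serial bit alone) is how tests pin a code path.
    simsimd_find_metric_punned(simsimd_metric_cos_k, simsimd_datatype_f32_k, supported, supported, &metric, &used);
    simsimd_distance_t d = 0;
    if (metric) ((dense_f32_kernel_t)metric)(a, b, n, &d); // both outputs are zeroed when nothing is viable
    return d;
}
```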
@@ -1274,30 +1279,30 @@ SIMSIMD_DYNAMIC int simsimd_uses_sierra(void);
  * @note The dot product is zero if and only if the two vectors are orthogonal.
  * @note Defined only for floating-point and integer data types.
  */
-SIMSIMD_DYNAMIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_dot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_vdot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_vdot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                        simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_vdot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_dot_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_dot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_vdot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_vdot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                        simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_vdot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d);
 
 /* Spatial distances
  * - Cosine distance: the cosine of the angle between two vectors.
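Of the pairs above, only the complex variants differ in meaning: `vdot` conjugates the first argument before multiplying. A small sketch under two assumptions - that complex arrays are interleaved (re, im, re, im, ...) with `n` counting scalar entries, and that the complex kernels write two values, real then imaginary, through the output pointer:

```c
#include <simsimd/simsimd.h>

// Sketch only: (1 + 2i), (3 + 4i) against (5 + 6i), (7 + 8i).
void complex_dot_example(void) {
    simsimd_f32_t a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    simsimd_f32_t b[4] = {5.0f, 6.0f, 7.0f, 8.0f};
    simsimd_distance_t dot[2], vdot[2];
    simsimd_dot_f32c(a, b, 4, dot);   // sum of a[i] * b[i]
    simsimd_vdot_f32c(a, b, 4, vdot); // sum of conj(a[i]) * b[i]
    // dot[0]/dot[1] hold the real/imaginary parts under the assumed convention.
}
```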
@@ -1312,42 +1317,42 @@ SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t con
  * @note The output distance value is zero if and only if the two vectors are identical.
  * @note Defined only for floating-point and integer data types.
  */
-SIMSIMD_DYNAMIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_cos_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_cos_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2sq_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
-                                   simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n,
-                                   simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_l2_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_cos_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_cos_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_cos_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_cos_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_cos_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_cos_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n,
+                                   simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n,
+                                   simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_l2_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
 
 /* Binary distances
  * - Hamming distance: the number of positions at which the corresponding bits are different.
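The three spatial kernels are tightly related, which makes for a cheap cross-check when validating new backends: for L2-normalized inputs the squared Euclidean distance is exactly twice the cosine distance, since ||a - b||^2 = 2 - 2 cos(theta). A sketch, not part of the patch:

```c
#include <simsimd/simsimd.h>

// Sketch only: the same pair of vectors through all three spatial kernels.
void spatial_example(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n) {
    simsimd_distance_t cos_d, l2sq_d, l2_d;
    simsimd_cos_f32(a, b, n, &cos_d);   // 1 - cos(angle between a and b)
    simsimd_l2sq_f32(a, b, n, &l2sq_d); // sum of (a[i] - b[i])^2
    simsimd_l2_f32(a, b, n, &l2_d);     // Euclidean distance, the square root of l2sq
    // For unit-length a and b: l2sq_d == 2 * cos_d, up to rounding.
}
```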
@@ -1362,10 +1367,10 @@ SIMSIMD_DYNAMIC void simsimd_l2_f64(simsimd_f64_t const* a, simsimd_f64_t const*
  * @note The output distance value is zero if and only if the two vectors are identical.
  * @note Defined only for binary data.
  */
-SIMSIMD_DYNAMIC void simsimd_hamming_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
-                                        simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
-                                        simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_hamming_b8(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n,
+                                        simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_jaccard_b8(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n,
+                                        simsimd_distance_t *d);
 
 /* Probability distributions
  * - Jensen-Shannon divergence: a measure of similarity between two probability distributions.
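One detail worth flagging for the `b8` kernels: `n` counts 8-bit words of packed bits, not individual bits, so 16 words describe 128 binary dimensions. A sketch, not part of the patch:

```c
#include <simsimd/simsimd.h>

// Sketch only: two 128-bit packed vectors, differing in the first word.
void binary_example(void) {
    simsimd_b8_t a[16] = {0xFF, 0x0F}; // remaining words are zero-initialized
    simsimd_b8_t b[16] = {0xF0, 0x0F};
    simsimd_distance_t hamming, jaccard;
    simsimd_hamming_b8(a, b, 16, &hamming); // popcount(a ^ b), the differing bits
    simsimd_jaccard_b8(a, b, 16, &jaccard); // 1 - |a & b| / |a | b|
}
```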
@@ -1381,22 +1386,22 @@ SIMSIMD_DYNAMIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t cons
  * @note The output divergence value is zero if and only if the two distributions are identical.
  * @note Defined only for floating-point data types.
  */
-SIMSIMD_DYNAMIC void simsimd_kl_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_kl_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_kl_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_kl_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_js_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_js_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_js_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
-SIMSIMD_DYNAMIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_kl_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_kl_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_kl_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_kl_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_js_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_js_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_js_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
+SIMSIMD_DYNAMIC void simsimd_js_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d);
 
 #else
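Both divergences expect non-negative inputs that sum to one; KL is asymmetric in its arguments, while JS is the symmetrized, bounded variant. A sketch, not part of the patch:

```c
#include <simsimd/simsimd.h>

// Sketch only: a uniform distribution against a skewed one.
void divergence_example(void) {
    simsimd_f32_t p[4] = {0.25f, 0.25f, 0.25f, 0.25f};
    simsimd_f32_t q[4] = {0.40f, 0.30f, 0.20f, 0.10f};
    simsimd_distance_t kl_pq, kl_qp, js;
    simsimd_kl_f32(p, q, 4, &kl_pq); // D_KL(p || q)
    simsimd_kl_f32(q, p, 4, &kl_qp); // generally differs from D_KL(p || q)
    simsimd_js_f32(p, q, 4, &js);    // symmetric in p and q
}
```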
@@ -1413,22 +1418,22 @@ SIMSIMD_DYNAMIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const*
  */
 // clang-format off
-SIMSIMD_PUBLIC int simsimd_uses_neon(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON; }
-SIMSIMD_PUBLIC int simsimd_uses_neon_f16(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_F16 ; }
-SIMSIMD_PUBLIC int simsimd_uses_neon_bf16(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_BF16; }
-SIMSIMD_PUBLIC int simsimd_uses_neon_i8(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_I8; }
-SIMSIMD_PUBLIC int simsimd_uses_sve(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE; }
-SIMSIMD_PUBLIC int simsimd_uses_sve_f16(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_F16; }
-SIMSIMD_PUBLIC int simsimd_uses_sve_bf16(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_BF16; }
-SIMSIMD_PUBLIC int simsimd_uses_sve_i8(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_I8; }
-SIMSIMD_PUBLIC int simsimd_uses_sve2(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE2; }
-SIMSIMD_PUBLIC int simsimd_uses_haswell(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_HASWELL; }
-SIMSIMD_PUBLIC int simsimd_uses_skylake(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SKYLAKE; }
-SIMSIMD_PUBLIC int simsimd_uses_ice(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_ICE; }
-SIMSIMD_PUBLIC int simsimd_uses_genoa(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_GENOA; }
-SIMSIMD_PUBLIC int simsimd_uses_sapphire(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SAPPHIRE; }
-SIMSIMD_PUBLIC int simsimd_uses_turin(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_TURIN; }
-SIMSIMD_PUBLIC int simsimd_uses_sierra(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SIERRA; }
+SIMSIMD_PUBLIC int simsimd_uses_neon(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON; }
+SIMSIMD_PUBLIC int simsimd_uses_neon_f16(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_F16 ; }
+SIMSIMD_PUBLIC int simsimd_uses_neon_bf16(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_BF16; }
+SIMSIMD_PUBLIC int simsimd_uses_neon_i8(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON_I8; }
+SIMSIMD_PUBLIC int simsimd_uses_sve(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE; }
+SIMSIMD_PUBLIC int simsimd_uses_sve_f16(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_F16; }
+SIMSIMD_PUBLIC int simsimd_uses_sve_bf16(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_BF16; }
+SIMSIMD_PUBLIC int simsimd_uses_sve_i8(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE_I8; }
+SIMSIMD_PUBLIC int simsimd_uses_sve2(void) { return _SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE2; }
+SIMSIMD_PUBLIC int simsimd_uses_haswell(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_HASWELL; }
+SIMSIMD_PUBLIC int simsimd_uses_skylake(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SKYLAKE; }
+SIMSIMD_PUBLIC int simsimd_uses_ice(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_ICE; }
+SIMSIMD_PUBLIC int simsimd_uses_genoa(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_GENOA; }
+SIMSIMD_PUBLIC int simsimd_uses_sapphire(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SAPPHIRE; }
+SIMSIMD_PUBLIC int simsimd_uses_turin(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_TURIN; }
+SIMSIMD_PUBLIC int simsimd_uses_sierra(void) { return _SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SIERRA; }
 SIMSIMD_PUBLIC int simsimd_uses_dynamic_dispatch(void) { return 0; }
 SIMSIMD_PUBLIC simsimd_capability_t simsimd_capabilities(void) { return _simsimd_capabilities_implementation(); }
 SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
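In this static-dispatch branch the `simsimd_uses_*` helpers report what the translation unit was compiled with (note `simsimd_uses_dynamic_dispatch` returning 0), while `simsimd_capabilities()` reflects what the running CPU offers. Comparing the two is a quick way to see why a slower kernel was picked; a sketch, not part of the patch:

```c
#include <stdio.h>
#include <simsimd/simsimd.h>

// Report compile-time target support next to runtime CPU capabilities.
void report_isa(void) {
    printf("compiled with Haswell kernels: %d\n", simsimd_uses_haswell());
    printf("dynamic dispatch enabled:      %d\n", simsimd_uses_dynamic_dispatch());
    simsimd_capability_t caps = simsimd_capabilities();
    printf("runtime supports Haswell:      %d\n", (caps & simsimd_cap_haswell_k) != 0);
}
```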
@@ -1456,8 +1461,8 @@
  * @note The dot product is zero if and only if the two vectors are orthogonal.
  * @note Defined only for floating-point and integer data types.
  */
-SIMSIMD_PUBLIC void simsimd_dot_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
-                                   simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n,
+                                   simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_NEON_F16
     simsimd_dot_i8_neon(a, b, n, d);
 #elif SIMSIMD_TARGET_ICE
@@ -1468,8 +1473,8 @@ SIMSIMD_PUBLIC void simsimd_dot_i8(simsimd_i8_t const* a, simsimd_i8_t const* b,
     simsimd_dot_i8_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n,
-                                   simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n,
+                                   simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_NEON_F16
     simsimd_dot_u8_neon(a, b, n, d);
 #elif SIMSIMD_TARGET_ICE
@@ -1480,8 +1485,8 @@ SIMSIMD_PUBLIC void simsimd_dot_u8(simsimd_u8_t const* a, simsimd_u8_t const* b,
     simsimd_dot_u8_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE_F16
     simsimd_dot_f16_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON_F16
@@ -1494,8 +1499,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const*
     simsimd_dot_f16_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_GENOA
     simsimd_dot_bf16_genoa(a, b, n, d);
 #elif SIMSIMD_TARGET_HASWELL
@@ -1506,8 +1511,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t con
     simsimd_dot_bf16_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_dot_f32_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON
@@ -1520,8 +1525,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const*
     simsimd_dot_f32_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                    simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                    simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_dot_f64_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_SKYLAKE
@@ -1530,8 +1535,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64(simsimd_f64_t const* a, simsimd_f64_t const*
     simsimd_dot_f64_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE_F16
     simsimd_dot_f16c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON_F16
@@ -1544,8 +1549,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const* a, simsimd_f16_t const
     simsimd_dot_f16c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_GENOA
     simsimd_dot_bf16c_genoa(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON_BF16
@@ -1554,8 +1559,8 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t co
     simsimd_dot_bf16c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_dot_f32c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON
@@ -1568,8 +1573,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const* a, simsimd_f32_t const
     simsimd_dot_f32c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                     simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                     simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_dot_f64c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_SKYLAKE
@@ -1578,8 +1583,8 @@ SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const* a, simsimd_f64_t const
     simsimd_dot_f64c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_vdot_f16c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON
@@ -1592,8 +1597,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const* a, simsimd_f16_t cons
     simsimd_vdot_f16c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
-                                       simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n,
+                                       simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_GENOA
     simsimd_vdot_bf16c_genoa(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON_BF16
@@ -1602,8 +1607,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t c
     simsimd_vdot_bf16c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_vdot_f32c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_NEON
@@ -1616,8 +1621,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const* a, simsimd_f32_t cons
     simsimd_vdot_f32c_serial(a, b, n, d);
 #endif
 }
-SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
-                                      simsimd_distance_t* d) {
+SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n,
+                                      simsimd_distance_t *d) {
 #if SIMSIMD_TARGET_SVE
     simsimd_vdot_f64c_sve(a, b, n, d);
 #elif SIMSIMD_TARGET_SKYLAKE
@@ -1640,8 +1645,8 @@ SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t cons
  * @note The output distance value is zero if and only if the two vectors are identical.
  * @note Defined only for floating-point and integer data types.
*/ -SIMSIMD_PUBLIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_cos_i8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1652,8 +1657,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_cos_i8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_cos_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_cos_u8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1664,8 +1669,8 @@ SIMSIMD_PUBLIC void simsimd_cos_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_cos_u8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE_F16 simsimd_cos_f16_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON_F16 @@ -1678,8 +1683,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* simsimd_cos_f16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_cos_bf16_genoa(a, b, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -1692,8 +1697,8 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t con simsimd_cos_bf16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_cos_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_cos_f32_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1706,8 +1711,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f32(simsimd_f32_t const* a, simsimd_f32_t const* simsimd_cos_f32_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_cos_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_cos_f64_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1718,8 +1723,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const* simsimd_cos_f64_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_l2sq_i8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1730,8 +1735,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b simsimd_l2sq_i8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - 
simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_l2sq_u8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1742,8 +1747,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8(simsimd_u8_t const* a, simsimd_u8_t const* b simsimd_l2sq_u8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE_F16 simsimd_l2sq_f16_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON_F16 @@ -1756,8 +1761,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const simsimd_l2sq_f16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_l2sq_bf16_genoa(a, b, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -1770,8 +1775,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t co simsimd_l2sq_bf16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_l2sq_f32_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1784,8 +1789,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f32(simsimd_f32_t const* a, simsimd_f32_t const simsimd_l2sq_f32_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2sq_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2sq_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_l2sq_f64_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1796,8 +1801,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f64(simsimd_f64_t const* a, simsimd_f64_t const simsimd_l2sq_f64_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_l2_i8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1808,8 +1813,8 @@ SIMSIMD_PUBLIC void simsimd_l2_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_l2_i8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_l2_u8_neon(a, b, n, d); #elif SIMSIMD_TARGET_ICE @@ -1820,8 +1825,8 @@ SIMSIMD_PUBLIC void simsimd_l2_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_l2_u8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, 
simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE_F16 simsimd_l2_f16_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON_F16 @@ -1834,8 +1839,8 @@ SIMSIMD_PUBLIC void simsimd_l2_f16(simsimd_f16_t const* a, simsimd_f16_t const* simsimd_l2_f16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_l2_bf16_genoa(a, b, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -1848,8 +1853,8 @@ SIMSIMD_PUBLIC void simsimd_l2_bf16(simsimd_bf16_t const* a, simsimd_bf16_t cons simsimd_l2_bf16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_l2_f32_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1862,8 +1867,8 @@ SIMSIMD_PUBLIC void simsimd_l2_f32(simsimd_f32_t const* a, simsimd_f32_t const* simsimd_l2_f32_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_l2_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_l2_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_l2_f64_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1888,8 +1893,8 @@ SIMSIMD_PUBLIC void simsimd_l2_f64(simsimd_f64_t const* a, simsimd_f64_t const* * @note The output distance value is zero if and only if the two vectors are identical. * @note Defined only for binary data. */ -SIMSIMD_PUBLIC void simsimd_hamming_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_hamming_b8(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_hamming_b8_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1902,8 +1907,8 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8(simsimd_b8_t const* a, simsimd_b8_t const simsimd_hamming_b8_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_jaccard_b8(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE simsimd_jaccard_b8_sve(a, b, n, d); #elif SIMSIMD_TARGET_NEON @@ -1931,8 +1936,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t const * @note The output divergence value is zero if and only if the two distributions are identical. * @note Defined only for floating-point data types. 
*/ -SIMSIMD_PUBLIC void simsimd_kl_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_kl_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_kl_f16_neon(a, b, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -1941,12 +1946,12 @@ SIMSIMD_PUBLIC void simsimd_kl_f16(simsimd_f16_t const* a, simsimd_f16_t const* simsimd_kl_f16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_kl_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_kl_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { simsimd_kl_bf16_serial(a, b, n, d); } -SIMSIMD_PUBLIC void simsimd_kl_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_kl_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_kl_f32_neon(a, b, n, d); #elif SIMSIMD_TARGET_SKYLAKE @@ -1955,12 +1960,12 @@ SIMSIMD_PUBLIC void simsimd_kl_f32(simsimd_f32_t const* a, simsimd_f32_t const* simsimd_kl_f32_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_kl_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_kl_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { simsimd_kl_f64_serial(a, b, n, d); } -SIMSIMD_PUBLIC void simsimd_js_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_js_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_js_f16_neon(a, b, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -1969,12 +1974,12 @@ SIMSIMD_PUBLIC void simsimd_js_f16(simsimd_f16_t const* a, simsimd_f16_t const* simsimd_js_f16_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_js_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_js_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { simsimd_js_bf16_serial(a, b, n, d); } -SIMSIMD_PUBLIC void simsimd_js_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_js_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { #if SIMSIMD_TARGET_NEON simsimd_js_f32_neon(a, b, n, d); #elif SIMSIMD_TARGET_SKYLAKE @@ -1983,8 +1988,8 @@ SIMSIMD_PUBLIC void simsimd_js_f32(simsimd_f32_t const* a, simsimd_f32_t const* simsimd_js_f32_serial(a, b, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_js_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *d) { simsimd_js_f64_serial(a, b, n, d); } @@ -1996,8 +2001,8 @@ SIMSIMD_PUBLIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const* * @param b_length The number of elements in the second array. * @param d The output for the number of elements in the intersection. 
*/ -SIMSIMD_PUBLIC void simsimd_intersect_u16(simsimd_u16_t const* a, simsimd_u16_t const* b, simsimd_size_t a_length, - simsimd_size_t b_length, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_intersect_u16(simsimd_u16_t const *a, simsimd_u16_t const *b, simsimd_size_t a_length, + simsimd_size_t b_length, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE2 simsimd_intersect_u16_sve2(a, b, a_length, b_length, d); #elif SIMSIMD_TARGET_NEON @@ -2009,8 +2014,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16(simsimd_u16_t const* a, simsimd_u16_t #endif } -SIMSIMD_PUBLIC void simsimd_intersect_u32(simsimd_u32_t const* a, simsimd_u32_t const* b, simsimd_size_t a_length, - simsimd_size_t b_length, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_intersect_u32(simsimd_u32_t const *a, simsimd_u32_t const *b, simsimd_size_t a_length, + simsimd_size_t b_length, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE2 simsimd_intersect_u32_sve2(a, b, a_length, b_length, d); #elif SIMSIMD_TARGET_NEON @@ -2032,9 +2037,9 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32(simsimd_u32_t const* a, simsimd_u32_t * @param b_length The number of elements in the second array. * @param d The output for the number of elements in the intersection. */ -SIMSIMD_PUBLIC void simsimd_spdot_counts_u16(simsimd_u16_t const* a, simsimd_u16_t const* b, - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, - simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_spdot_counts_u16(simsimd_u16_t const *a, simsimd_u16_t const *b, + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, + simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE2 simsimd_spdot_counts_u16_sve2(a, b, a_weights, b_weights, a_length, b_length, d); #elif SIMSIMD_TARGET_TURIN @@ -2044,9 +2049,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16(simsimd_u16_t const* a, simsimd_u16 #endif } -SIMSIMD_PUBLIC void simsimd_spdot_weights_u16(simsimd_u16_t const* a, simsimd_u16_t const* b, - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, - simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_spdot_weights_u16(simsimd_u16_t const *a, simsimd_u16_t const *b, + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, + simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SVE2 simsimd_spdot_weights_u16_sve2(a, b, a_weights, b_weights, a_length, b_length, d); #elif SIMSIMD_TARGET_TURIN @@ -2064,17 +2069,17 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16(simsimd_u16_t const* a, simsimd_u1 * @param n The number of dimensions in the vectors. * @param d The output for the number of elements in the intersection. 
*/ -SIMSIMD_PUBLIC void simsimd_bilinear_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_f64_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_bilinear_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { simsimd_bilinear_f64_serial(a, b, c, n, d); } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_f64_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { simsimd_mahalanobis_f64_serial(a, b, c, n, d); } -SIMSIMD_PUBLIC void simsimd_bilinear_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_f32_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_bilinear_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SKYLAKE simsimd_bilinear_f32_skylake(a, b, c, n, d); #elif SIMSIMD_TARGET_NEON @@ -2083,8 +2088,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f32(simsimd_f32_t const* a, simsimd_f32_t c simsimd_bilinear_f32_serial(a, b, c, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_f32_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SKYLAKE simsimd_mahalanobis_f32_skylake(a, b, c, n, d); #elif SIMSIMD_TARGET_NEON @@ -2093,8 +2098,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f32(simsimd_f32_t const* a, simsimd_f32_ simsimd_mahalanobis_f32_serial(a, b, c, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_bilinear_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_bilinear_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SAPPHIRE simsimd_bilinear_f16_sapphire(a, b, c, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -2105,8 +2110,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16(simsimd_f16_t const* a, simsimd_f16_t c simsimd_bilinear_f16_serial(a, b, c, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_mahalanobis_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_SAPPHIRE simsimd_mahalanobis_f16_sapphire(a, b, c, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -2117,8 +2122,8 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16(simsimd_f16_t const* a, simsimd_f16_ simsimd_mahalanobis_f16_serial(a, b, c, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_bilinear_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_bilinear_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_bilinear_bf16_genoa(a, b, c, n, d); #elif SIMSIMD_TARGET_HASWELL @@ -2129,8 +2134,8 @@ SIMSIMD_PUBLIC void 
simsimd_bilinear_bf16(simsimd_bf16_t const* a, simsimd_bf16_ simsimd_bilinear_bf16_serial(a, b, c, n, d); #endif } -SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, - simsimd_size_t n, simsimd_distance_t* d) { +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, + simsimd_size_t n, simsimd_distance_t *d) { #if SIMSIMD_TARGET_GENOA simsimd_mahalanobis_bf16_genoa(a, b, c, n, d); #elif SIMSIMD_TARGET_HASWELL diff --git a/include/simsimd/sparse.h b/include/simsimd/sparse.h index 450abe9d..fb5c849e 100644 --- a/include/simsimd/sparse.h +++ b/include/simsimd/sparse.h @@ -58,209 +58,204 @@ extern "C" { * but uses clever galloping logic, if the arrays significantly differ in size. */ SIMSIMD_PUBLIC void simsimd_intersect_u16_serial( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_intersect_u32_serial( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_serial( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_serial( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); /* Implements the most naive set intersection algorithm, similar to `std::set_intersection in C++ STL`, * naively enumerating the elements of two arrays. 
*/ SIMSIMD_PUBLIC void simsimd_intersect_u16_accurate( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_intersect_u32_accurate( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_accurate( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_accurate( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); /* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes. * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs. */ SIMSIMD_PUBLIC void simsimd_intersect_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); /* SIMD-powered backends for various generations of AVX512 CPUs. * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop. * Ice Lake, however, is needed even for the most basic kernels to perform integer matching. 
*/ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); /* SIMD-powered backends for AMD Turin CPUs with cheap VP2INTERSECT instructions. * On the Intel side, only mobile Tiger Lake support them, but have prohibitively high latency. */ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); + simsimd_distance_t *results); SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results); - -#define SIMSIMD_MAKE_INTERSECT_LINEAR(name, input_type, counter_type) \ - SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \ - simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_size_t a_length, \ - simsimd_size_t b_length, simsimd_distance_t* result) { \ - simsimd_##counter_type##_t intersection_size = 0; \ - simsimd_size_t i = 0, j = 0; \ - while (i != a_length && j != b_length) { \ - simsimd_##input_type##_t ai = a[i]; \ - simsimd_##input_type##_t bj = b[j]; \ - intersection_size += ai == bj; \ - i += ai < bj; \ - j += ai >= bj; \ - } \ - *result = intersection_size; \ + simsimd_distance_t *results); + +#define SIMSIMD_MAKE_INTERSECT_LINEAR(name, input_type, counter_type) \ + SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_size_t a_length, \ + simsimd_size_t b_length, simsimd_distance_t *result) { \ + simsimd_##counter_type##_t intersection_size = 0; \ + simsimd_size_t i = 0, j = 0; \ + while (i != a_length && j != b_length) { \ + simsimd_##input_type##_t ai = a[i]; \ + simsimd_##input_type##_t bj = b[j]; \ + intersection_size += ai == bj; \ + i += ai < bj; \ + j += ai >= bj; \ + } \ + *result = intersection_size; \ } SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u16, size) // simsimd_intersect_u16_accurate SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accurate -#define 
SIMSIMD_MAKE_INTERSECT_WEIGHTED(name, input_type, counter_type, weight_type, accumulator_type, \ - load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_intersect_##input_type##weight_type##_##name( \ - simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, \ - simsimd_##weight_type##_t const* a_weights, simsimd_##weight_type##_t const* b_weights, \ - simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t* results) { \ - simsimd_##counter_type##_t intersection_size = 0; \ - simsimd_##accumulator_type##_t weights_product = 0; \ - simsimd_size_t i = 0, j = 0; \ - while (i != a_length && j != b_length) { \ - simsimd_##input_type##_t ai = a[i]; \ - simsimd_##input_type##_t bj = b[j]; \ - int matches = ai == bj; \ - simsimd_##counter_type##_t awi = load_and_convert(a_weights + i); \ - simsimd_##counter_type##_t bwi = load_and_convert(b_weights + i); \ - weights_product += matches * awi * bwi; \ - intersection_size += matches; \ - i += ai < bj; \ - j += ai >= bj; \ - } \ - results[0] = intersection_size; \ - results[1] = weights_product; \ +#define SIMSIMD_MAKE_INTERSECT_WEIGHTED(name, variation, input_type, counter_type, weight_type, accumulator_type, \ + load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_##variation##_##input_type##_##name( \ + simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, \ + simsimd_##weight_type##_t const *a_weights, simsimd_##weight_type##_t const *b_weights, \ + simsimd_size_t a_length, simsimd_size_t b_length, simsimd_distance_t *results) { \ + simsimd_##counter_type##_t intersection_size = 0; \ + simsimd_##accumulator_type##_t weights_product = 0; \ + simsimd_size_t i = 0, j = 0; \ + while (i != a_length && j != b_length) { \ + simsimd_##input_type##_t ai = a[i]; \ + simsimd_##input_type##_t bj = b[j]; \ + int matches = ai == bj; \ + simsimd_##counter_type##_t awi = load_and_convert(a_weights + i); \ + simsimd_##counter_type##_t bwi = load_and_convert(b_weights + i); \ + weights_product += matches * awi * bwi; \ + intersection_size += matches; \ + i += ai < bj; \ + j += ai >= bj; \ + } \ + results[0] = intersection_size; \ + results[1] = weights_product; \ } -SIMSIMD_MAKE_INTERSECT_WEIGHTED(accurate, u16, size, i16, i64, +SIMSIMD_MAKE_INTERSECT_WEIGHTED(accurate, spdot_counts, u16, size, i16, i64, SIMSIMD_DEREFERENCE) // simsimd_spdot_counts_u16_accurate -SIMSIMD_MAKE_INTERSECT_WEIGHTED(accurate, u16, size, bf16, f64, +SIMSIMD_MAKE_INTERSECT_WEIGHTED(accurate, spdot_weights, u16, size, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_spdot_weights_u16_accurate -#define SIMSIMD_MAKE_INTERSECT_GALLOPING(name, input_type, counter_type) \ - SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* array, \ - simsimd_size_t start, simsimd_size_t length, \ - simsimd_##input_type##_t val) { \ - simsimd_size_t low = start; \ - simsimd_size_t high = start + 1; \ - while (high < length && array[high] < val) { \ - low = high; \ - high = (2 * high < length) ? 
2 * high : length; \ - } \ - while (low < high) { \ - simsimd_size_t mid = low + (high - low) / 2; \ - if (array[mid] < val) { \ - low = mid + 1; \ - } else { \ - high = mid; \ - } \ - } \ - return low; \ - } \ - \ - SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \ - simsimd_##input_type##_t const* shorter, simsimd_##input_type##_t const* longer, \ - simsimd_size_t shorter_length, simsimd_size_t longer_length, simsimd_distance_t* result) { \ - /* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */ \ - if (longer_length < shorter_length) { \ - simsimd_##input_type##_t const* temp = shorter; \ - shorter = longer; \ - longer = temp; \ - simsimd_size_t temp_length = shorter_length; \ - shorter_length = longer_length; \ - longer_length = temp_length; \ - } \ - \ - /* Use the accurate implementation if galloping is not beneficial */ \ - if (longer_length < 64 * shorter_length) { \ - simsimd_intersect_##input_type##_accurate(shorter, longer, shorter_length, longer_length, result); \ - return; \ - } \ - \ - /* Perform galloping, shrinking the target range */ \ - simsimd_##counter_type##_t intersection_size = 0; \ - simsimd_size_t j = 0; \ - for (simsimd_size_t i = 0; i < shorter_length; ++i) { \ - simsimd_##input_type##_t shorter_i = shorter[i]; \ - j = simsimd_galloping_search_##input_type(longer, j, longer_length, shorter_i); \ - if (j < longer_length && longer[j] == shorter_i) { \ - intersection_size++; \ - } \ - } \ - *result = intersection_size; \ +#define SIMSIMD_MAKE_INTERSECT_GALLOPING(name, input_type, counter_type) \ + SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const *array, \ + simsimd_size_t start, simsimd_size_t length, \ + simsimd_##input_type##_t val) { \ + simsimd_size_t low = start; \ + simsimd_size_t high = start + 1; \ + while (high < length && array[high] < val) { \ + low = high; \ + high = (2 * high < length) ? 
2 * high : length; \ + } \ + while (low < high) { \ + simsimd_size_t mid = low + (high - low) / 2; \ + if (array[mid] < val) { low = mid + 1; } \ + else { high = mid; } \ + } \ + return low; \ + } \ + \ + SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \ + simsimd_##input_type##_t const *shorter, simsimd_##input_type##_t const *longer, \ + simsimd_size_t shorter_length, simsimd_size_t longer_length, simsimd_distance_t *result) { \ + /* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */ \ + if (longer_length < shorter_length) { \ + simsimd_##input_type##_t const *temp = shorter; \ + shorter = longer; \ + longer = temp; \ + simsimd_size_t temp_length = shorter_length; \ + shorter_length = longer_length; \ + longer_length = temp_length; \ + } \ + \ + /* Use the accurate implementation if galloping is not beneficial */ \ + if (longer_length < 64 * shorter_length) { \ + simsimd_intersect_##input_type##_accurate(shorter, longer, shorter_length, longer_length, result); \ + return; \ + } \ + \ + /* Perform galloping, shrinking the target range */ \ + simsimd_##counter_type##_t intersection_size = 0; \ + simsimd_size_t j = 0; \ + for (simsimd_size_t i = 0; i < shorter_length; ++i) { \ + simsimd_##input_type##_t shorter_i = shorter[i]; \ + j = simsimd_galloping_search_##input_type(longer, j, longer_length, shorter_i); \ + if (j < longer_length && longer[j] == shorter_i) { intersection_size++; } \ + } \ + *result = intersection_size; \ } SIMSIMD_MAKE_INTERSECT_GALLOPING(serial, u16, size) // simsimd_intersect_u16_serial SIMSIMD_MAKE_INTERSECT_GALLOPING(serial, u32, size) // simsimd_intersect_u32_serial -SIMSIMD_MAKE_INTERSECT_WEIGHTED(serial, u16, size, i16, i32, +SIMSIMD_MAKE_INTERSECT_WEIGHTED(serial, spdot_counts, u16, size, i16, i32, SIMSIMD_DEREFERENCE) // simsimd_spdot_counts_u16_serial -SIMSIMD_MAKE_INTERSECT_WEIGHTED(serial, u16, size, bf16, f32, +SIMSIMD_MAKE_INTERSECT_WEIGHTED(serial, spdot_weights, u16, size, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_spdot_weights_u16_serial /* The AVX-512 implementations are inspired by the "Faster-Than-Native Alternatives @@ -276,11 +271,11 @@ SIMSIMD_MAKE_INTERSECT_WEIGHTED(serial, u16, size, bf16, f32, * - `_mm512_permutexvar_epi16` - needs BW - 4-6 cycles latency * - `_mm512_permutexvar_epi8` - needs VBMI - 3 cycles latency */ -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_ICE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2"))), \ apply_to = function) /** @@ -396,9 +391,9 @@ SIMSIMD_INTERNAL simsimd_u16_t _simsimd_intersect_u32x16_ice(__m512i a, __m512i } SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 64 && b_length < 64) { @@ -406,8 +401,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // return; } - simsimd_u16_t const* const a_end = a + a_length; - simsimd_u16_t const* const b_end = b + b_length; + simsimd_u16_t const *const a_end = a + 
a_length; + simsimd_u16_t const *const b_end = b + b_length; simsimd_size_t c = 0; union vec_t { __m512i zmm; @@ -416,8 +411,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // } a_vec, b_vec; while (a + 32 < a_end && b + 32 < b_end) { - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); // Intersecting registers with `_simsimd_intersect_u16x32_ice` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. @@ -429,13 +424,13 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 64 < a_end) { a += 32; - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); a_max = a_vec.u16[31]; } a_min = a_vec.u16[0]; while (b_max < a_min && b + 64 < b_end) { b += 32; - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); b_max = b_vec.u16[31]; } b_min = b_vec.u16[0]; @@ -448,8 +443,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // // _mm512_mask_compressstoreu_epi16(c, a_matches, a_vec); c += _mm_popcnt_u32(a_matches); // MSVC has no `_popcnt32` - __m512i a_last_broadcasted = _mm512_set1_epi16(*(short const*)&a_max); - __m512i b_last_broadcasted = _mm512_set1_epi16(*(short const*)&b_max); + __m512i a_last_broadcasted = _mm512_set1_epi16(*(short const *)&a_max); + __m512i b_last_broadcasted = _mm512_set1_epi16(*(short const *)&b_max); __mmask32 a_step_mask = _mm512_cmple_epu16_mask(a_vec.zmm, b_last_broadcasted); __mmask32 b_step_mask = _mm512_cmple_epu16_mask(b_vec.zmm, a_last_broadcasted); a += 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); @@ -461,9 +456,9 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_ice( // } SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 32 && b_length < 32) { @@ -471,8 +466,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // return; } - simsimd_u32_t const* const a_end = a + a_length; - simsimd_u32_t const* const b_end = b + b_length; + simsimd_u32_t const *const a_end = a + a_length; + simsimd_u32_t const *const b_end = b + b_length; simsimd_size_t c = 0; union vec_t { __m512i zmm; @@ -481,8 +476,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // } a_vec, b_vec; while (a + 16 < a_end && b + 16 < b_end) { - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); // Intersecting registers with `_simsimd_intersect_u32x16_ice` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. 
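
The overlap checks above have a simple scalar reading: compare whole register-sized blocks by their minima and maxima before doing any per-element matching, and advance each cursor only past elements that can no longer match. A minimal C sketch of that strategy, assuming sorted arrays of unique values and using `BLOCK` as a stand-in for one ZMM register's worth of lanes — an illustration of the idea, not the kernel itself:

```c
#include <stddef.h>
#include <stdint.h>

/* Scalar sketch of the block-skipping used by the AVX-512 kernels: skip
 * whole blocks whose value ranges are disjoint, match all pairs otherwise.
 * `BLOCK` stands in for one ZMM register (32 x u16 or 16 x u32 lanes);
 * inputs are assumed sorted and unique, as in the kernels above. */
enum { BLOCK = 32 };

static size_t intersect_blocks_u16(uint16_t const *a, size_t a_len, //
                                   uint16_t const *b, size_t b_len) {
    size_t i = 0, j = 0, count = 0;
    while (i + BLOCK <= a_len && j + BLOCK <= b_len) {
        uint16_t a_max = a[i + BLOCK - 1], b_max = b[j + BLOCK - 1];
        // If the blocks' value ranges are disjoint, skip a whole block:
        if (a_max < b[j]) { i += BLOCK; continue; }
        if (b_max < a[i]) { j += BLOCK; continue; }
        // Overlapping ranges: match all pairs (the SIMD kernel shuffles here)
        for (size_t k = 0; k != BLOCK; ++k)
            for (size_t l = 0; l != BLOCK; ++l)
                count += a[i + k] == b[j + l];
        // Advance past elements <= the other block's maximum - the scalar
        // equivalent of the `cmple` + `lzcnt` step in the kernels
        size_t a_step = 0, b_step = 0;
        while (a_step != BLOCK && a[i + a_step] <= b_max) ++a_step;
        while (b_step != BLOCK && b[j + b_step] <= a_max) ++b_step;
        i += a_step, j += b_step;
    }
    return count; // tails shorter than BLOCK fall back to the serial path
}
```
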
@@ -494,13 +489,13 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 32 < a_end) { a += 16; - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); a_max = a_vec.u32[15]; } a_min = a_vec.u32[0]; while (b_max < a_min && b + 32 < b_end) { b += 16; - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); b_max = b_vec.u32[15]; } b_min = b_vec.u32[0]; @@ -513,8 +508,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // // _mm512_mask_compressstoreu_epi32(c, a_matches, a_vec); c += _mm_popcnt_u32(a_matches); // MSVC has no `_popcnt32` - __m512i a_last_broadcasted = _mm512_set1_epi32(*(int const*)&a_max); - __m512i b_last_broadcasted = _mm512_set1_epi32(*(int const*)&b_max); + __m512i a_last_broadcasted = _mm512_set1_epi32(*(int const *)&a_max); + __m512i b_last_broadcasted = _mm512_set1_epi32(*(int const *)&b_max); __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_last_broadcasted); __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_last_broadcasted); a += 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); @@ -531,17 +526,17 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_ice( // #if SIMSIMD_TARGET_TURIN #pragma GCC push_options -#pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2", "avx512bf16", \ +#pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2", "avx512bf16", \ "avx512vp2intersect") -#pragma clang attribute push( \ - __attribute__(( \ - target("avx2,avx512f,avx512vl,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2,avx512bf16,avx512vp2intersect"))), \ +#pragma clang attribute push( \ + __attribute__(( \ + target("avx2,avx512f,avx512vl,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2,avx512bf16,avx512vp2intersect"))), \ apply_to = function) SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 64 && b_length < 64) { @@ -552,8 +547,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // //! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant! //! So instead of jumping through 32 entries at a time, like on Ice Lake, we will //! step through 16 entries at a time. - simsimd_u16_t const* const a_end = a + a_length; - simsimd_u16_t const* const b_end = b + b_length; + simsimd_u16_t const *const a_end = a + a_length; + simsimd_u16_t const *const b_end = b + b_length; simsimd_size_t c = 0; union vec_t { __m256i ymm; @@ -562,8 +557,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // } a_vec, b_vec; while (a + 16 < a_end && b + 16 < b_end) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); // Intersecting registers with `_mm512_2intersect_epi16_mask` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. 
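
For callers, all of this machinery stays hidden behind the dispatch functions declared earlier. A hypothetical end-to-end usage sketch, assuming the library's umbrella header `<simsimd/simsimd.h>`:

```c
#include <stdio.h>
#include <simsimd/simsimd.h>

int main(void) {
    // Sorted, unique identifiers - the input format the kernels expect
    simsimd_u16_t a[] = {3, 7, 11, 42, 100, 255};
    simsimd_u16_t b[] = {7, 42, 99, 255, 1024};
    simsimd_distance_t size = 0;
    // Dispatches to SVE2, VP2INTERSECT, AVX-512, NEON, or serial code
    simsimd_intersect_u16(a, b, 6, 5, &size);
    printf("intersection size: %.0f\n", size); // prints 3: {7, 42, 255}
    return 0;
}
```
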
@@ -575,13 +570,13 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 32 < a_end) { a += 16; - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); a_max = a_vec.u16[15]; } a_min = a_vec.u16[0]; while (b_max < a_min && b + 32 < b_end) { b += 16; - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); b_max = b_vec.u16[15]; } b_min = b_vec.u16[0]; @@ -597,8 +592,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // // _mm512_mask_compressstoreu_epi16(c, a_matches_any_in_b, a_vec); c += _mm_popcnt_u32(a_matches_any_in_b); // MSVC has no `_popcnt32` - __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const*)&a_max); - __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const*)&b_max); + __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const *)&a_max); + __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const *)&b_max); __mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_last_broadcasted); __mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_last_broadcasted); a += 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); //? Is this correct? Needs testing! @@ -610,9 +605,9 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_turin( // } SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 32 && b_length < 32) { @@ -620,8 +615,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // return; } - simsimd_u32_t const* const a_end = a + a_length; - simsimd_u32_t const* const b_end = b + b_length; + simsimd_u32_t const *const a_end = a + a_length; + simsimd_u32_t const *const b_end = b + b_length; simsimd_size_t c = 0; union vec_t { __m512i zmm; @@ -630,8 +625,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // } a_vec, b_vec; while (a + 16 < a_end && b + 16 < b_end) { - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); // Intersecting registers with `_mm512_2intersect_epi32` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. 
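
The Turin kernels lean on `_mm512_2intersect_epi32`, which produces two masks at once: one per input register, flagging lanes whose value occurs anywhere in the other register. A scalar model of that two-mask semantics for a pair of 16-lane blocks — illustrative only, since the hardware performs all 256 comparisons in parallel:

```c
#include <stdint.h>

/* Scalar model of `_mm512_2intersect_epi32`: for two blocks of 16 values,
 * produce one bitmask per input marking which of its lanes match anywhere
 * in the other block. */
static void two_way_intersect_16x32(uint32_t const a[16], uint32_t const b[16],
                                    uint16_t *a_matches, uint16_t *b_matches) {
    uint16_t ma = 0, mb = 0;
    for (int i = 0; i != 16; ++i)
        for (int j = 0; j != 16; ++j)
            if (a[i] == b[j]) ma |= (uint16_t)(1u << i), mb |= (uint16_t)(1u << j);
    *a_matches = ma;
    *b_matches = mb;
}
```

The kernel then popcounts one of the masks to grow the running intersection size, exactly as `_mm_popcnt_u32(a_matches_any_in_b)` does above.
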
@@ -643,13 +638,13 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 32 < a_end) { a += 16; - a_vec.zmm = _mm512_loadu_si512((__m512i const*)a); + a_vec.zmm = _mm512_loadu_si512((__m512i const *)a); a_max = a_vec.u32[15]; } a_min = a_vec.u32[0]; while (b_max < a_min && b + 32 < b_end) { b += 16; - b_vec.zmm = _mm512_loadu_si512((__m512i const*)b); + b_vec.zmm = _mm512_loadu_si512((__m512i const *)b); b_max = b_vec.u32[15]; } b_min = b_vec.u32[0]; @@ -663,8 +658,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // // _mm512_mask_compressstoreu_epi32(c, a_matches_any_in_b, a_vec); c += _mm_popcnt_u32(a_matches_any_in_b); // MSVC has no `_popcnt32` - __m512i a_last_broadcasted = _mm512_set1_epi32(*(int const*)&a_max); - __m512i b_last_broadcasted = _mm512_set1_epi32(*(int const*)&b_max); + __m512i a_last_broadcasted = _mm512_set1_epi32(*(int const *)&a_max); + __m512i b_last_broadcasted = _mm512_set1_epi32(*(int const *)&b_max); __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_last_broadcasted); __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_last_broadcasted); a += 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); @@ -676,10 +671,10 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_turin( // } SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 64 && b_length < 64) { @@ -690,8 +685,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // //! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant! //! So instead of jumping through 32 entries at a time, like on Ice Lake, we will //! step through 16 entries at a time. - simsimd_u16_t const* const a_end = a + a_length; - simsimd_u16_t const* const b_end = b + b_length; + simsimd_u16_t const *const a_end = a + a_length; + simsimd_u16_t const *const b_end = b + b_length; simsimd_size_t intersection_size = 0; union vec_t { __m256i ymm; @@ -702,8 +697,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // product_vec.ymmps = _mm256_setzero_ps(); while (a + 16 < a_end && b + 16 < b_end) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); // Intersecting registers with `_mm512_2intersect_epi16_mask` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. 
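
Behind the masking and compression, `simsimd_spdot_weights_u16_turin` computes a simple quantity: the number of shared indices, plus the dot product of the weights attached to them. A scalar reference for that contract, with `float` standing in for `simsimd_bf16_t` purely for readability — an assumption made for this sketch:

```c
#include <stddef.h>
#include <stdint.h>

/* Scalar reference for the weighted sparse dot product: walk two sorted
 * index arrays and, for every shared index, accumulate the product of the
 * corresponding weights. Outputs match the kernel's two-slot convention:
 * results[0] is the match count, results[1] the weights' dot product. */
static void spdot_weights_scalar(uint16_t const *a, uint16_t const *b,
                                 float const *a_weights, float const *b_weights,
                                 size_t a_len, size_t b_len, double results[2]) {
    size_t i = 0, j = 0, matches = 0;
    double product = 0;
    while (i != a_len && j != b_len) {
        if (a[i] == b[j]) matches++, product += (double)a_weights[i] * b_weights[j];
        // On equality both cursors advance; otherwise only the smaller one
        size_t advance_a = a[i] <= b[j], advance_b = b[j] <= a[i];
        i += advance_a, j += advance_b;
    }
    results[0] = (double)matches;
    results[1] = product;
}
```
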
@@ -715,13 +710,13 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 32 < a_end) { a += 16, a_weights += 16; - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); a_max = a_vec.u16[15]; } a_min = a_vec.u16[0]; while (b_max < a_min && b + 32 < b_end) { b += 16, b_weights += 16; - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); b_max = b_vec.u16[15]; } b_min = b_vec.u16[0]; @@ -740,15 +735,15 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // // Load and shift all the relevant weights to the start of the vector before doing the dot product if (a_matches_count_in_b) { - __m256i a_weights_vec = _mm256_lddqu_si256((__m256i const*)a_weights); + __m256i a_weights_vec = _mm256_lddqu_si256((__m256i const *)a_weights); a_weights_vec = _mm256_maskz_compress_epi16(a_matches_any_in_b, a_weights_vec); - __m256i b_weights_vec = _mm256_lddqu_si256((__m256i const*)b_weights); + __m256i b_weights_vec = _mm256_lddqu_si256((__m256i const *)b_weights); b_weights_vec = _mm256_maskz_compress_epi16(b_matches_any_in_a, b_weights_vec); product_vec.ymmps = _mm256_dpbf16_ps(product_vec.ymmps, (__m256bh)a_weights_vec, (__m256bh)b_weights_vec); } - __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const*)&a_max); - __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const*)&b_max); + __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const *)&a_max); + __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const *)&b_max); __mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_last_broadcasted); __mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_last_broadcasted); int a_step = 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); //? Is this correct? Needs testing! @@ -762,10 +757,10 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_turin( // } SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 64 && b_length < 64) { @@ -776,8 +771,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // //! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant! //! So instead of jumping through 32 entries at a time, like on Ice Lake, we will //! step through 16 entries at a time. 
- simsimd_u16_t const* const a_end = a + a_length; - simsimd_u16_t const* const b_end = b + b_length; + simsimd_u16_t const *const a_end = a + a_length; + simsimd_u16_t const *const b_end = b + b_length; simsimd_size_t intersection_size = 0; union vec_t { __m256i ymm; @@ -787,8 +782,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // product_vec.ymm = _mm256_setzero_si256(); while (a + 16 < a_end && b + 16 < b_end) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); // Intersecting registers with `_mm512_2intersect_epi16_mask` involves a lot of shuffling // and comparisons, so we want to avoid it if the slices don't overlap at all.. @@ -800,13 +795,13 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // // If the slices don't overlap, advance the appropriate pointer while (a_max < b_min && a + 32 < a_end) { a += 16, a_weights += 16; - a_vec.ymm = _mm256_lddqu_si256((__m256i const*)a); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); a_max = a_vec.u16[15]; } a_min = a_vec.u16[0]; while (b_max < a_min && b + 32 < b_end) { b += 16, b_weights += 16; - b_vec.ymm = _mm256_lddqu_si256((__m256i const*)b); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); b_max = b_vec.u16[15]; } b_min = b_vec.u16[0]; @@ -825,15 +820,15 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // // Load and shift all the relevant weights to the start of the vector before doing the dot product if (a_matches_count_in_b) { - __m256i a_weights_vec = _mm256_lddqu_si256((__m256i const*)a_weights); + __m256i a_weights_vec = _mm256_lddqu_si256((__m256i const *)a_weights); a_weights_vec = _mm256_maskz_compress_epi16(a_matches_any_in_b, a_weights_vec); - __m256i b_weights_vec = _mm256_lddqu_si256((__m256i const*)b_weights); + __m256i b_weights_vec = _mm256_lddqu_si256((__m256i const *)b_weights); b_weights_vec = _mm256_maskz_compress_epi16(b_matches_any_in_a, b_weights_vec); product_vec.ymm = _mm256_dpwssds_epi32(product_vec.ymm, a_weights_vec, b_weights_vec); } - __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const*)&a_max); - __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const*)&b_max); + __m256i a_last_broadcasted = _mm256_set1_epi16(*(short const *)&a_max); + __m256i b_last_broadcasted = _mm256_set1_epi16(*(short const *)&b_max); __mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_last_broadcasted); __mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_last_broadcasted); int a_step = 32 - _lzcnt_u32((simsimd_u32_t)a_step_mask); //? Is this correct? Needs testing! 
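
The `//? Is this correct?` steps above all rely on the same identity: a `cmple` mask flags every lane not exceeding the other block's maximum, and because the inputs are sorted that mask is a contiguous run of bits starting at lane 0 — so its population count equals one plus the index of its highest set bit, i.e. `32 - lzcnt(mask)`. A small self-contained sketch of that identity, using a GCC/Clang builtin as a stand-in:

```c
#include <assert.h>
#include <stdint.h>

/* For sorted lanes, a `cmple` mask is a contiguous run from bit 0, so the
 * number of lanes to advance is `32 - lzcnt(mask)`. `__builtin_clz` stands
 * in for `_lzcnt_u32` here; unlike `lzcnt`, it is undefined for a zero
 * argument, hence the explicit guard. */
static inline int step_from_mask(uint32_t mask) {
    return mask ? 32 - __builtin_clz(mask) : 0;
}

int main(void) {
    assert(step_from_mask(0x0000FFFFu) == 16); // all 16 lanes advance
    assert(step_from_mask(0x00000007u) == 3);  // first three lanes advance
    assert(step_from_mask(0) == 0);            // no lane may advance
    return 0;
}
```
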
@@ -849,9 +844,9 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_turin( // #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_TURIN -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a") @@ -871,8 +866,7 @@ SIMSIMD_INTERNAL int _simsimd_clz_u64(simsimd_u64_t x) { return __builtin_clzll(x); #else int n = 0; - while ((x & 0x8000000000000000ull) == 0) - n++, x <<= 1; + while ((x & 0x8000000000000000ull) == 0) n++, x <<= 1; return n; #endif } @@ -911,9 +905,9 @@ SIMSIMD_INTERNAL uint16x8_t _simsimd_intersect_u16x8_neon(uint16x8_t a, uint16x8 } SIMSIMD_PUBLIC void simsimd_intersect_u16_neon( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 32 && b_length < 32) { @@ -921,8 +915,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_neon( // return; } - simsimd_u16_t const* const a_end = a + a_length; - simsimd_u16_t const* const b_end = b + b_length; + simsimd_u16_t const *const a_end = a + a_length; + simsimd_u16_t const *const b_end = b + b_length; union vec_t { uint16x8_t u16x8; simsimd_u16_t u16[8]; @@ -994,9 +988,9 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_neon( // } SIMSIMD_PUBLIC void simsimd_intersect_u32_neon( // - simsimd_u32_t const* a, simsimd_u32_t const* b, // + simsimd_u32_t const *a, simsimd_u32_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // The baseline implementation for very small arrays (2 registers or less) can be quite simple: if (a_length < 32 && b_length < 32) { @@ -1004,8 +998,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_neon( // return; } - simsimd_u32_t const* const a_end = a + a_length; - simsimd_u32_t const* const b_end = b + b_length; + simsimd_u32_t const *const a_end = a + a_length; + simsimd_u32_t const *const b_end = b + b_length; union vec_t { uint32x4_t u32x4; simsimd_u32_t u32[4]; @@ -1099,10 +1093,10 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_neon( // #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+sve2"))), apply_to = function) SIMSIMD_PUBLIC void simsimd_intersect_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // A single SVE lane is 128 bits wide, so one lane fits 8 values. simsimd_size_t const register_size = svcnth(); @@ -1168,8 +1162,8 @@ SIMSIMD_PUBLIC void simsimd_intersect_u16_sve2( // *results = c; } -SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const* a, simsimd_u32_t const* b, simsimd_size_t a_length, - simsimd_size_t b_length, simsimd_distance_t* results) { +SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const *a, simsimd_u32_t const *b, simsimd_size_t a_length, + simsimd_size_t b_length, simsimd_distance_t *results) { // A single SVE lane is 128 bits wide, so one lane fits 4 values. 
simsimd_size_t const register_size = svcntw(); @@ -1264,10 +1258,10 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const* a, simsimd_u } SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_i16_t const *a_weights, simsimd_i16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // A single SVE lane is 128 bits wide, so one lane fits 8 values. simsimd_size_t const register_size = svcnth(); @@ -1347,10 +1341,10 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( // #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function) SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( // - simsimd_u16_t const* a, simsimd_u16_t const* b, // - simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, // + simsimd_u16_t const *a, simsimd_u16_t const *b, // + simsimd_bf16_t const *a_weights, simsimd_bf16_t const *b_weights, // simsimd_size_t a_length, simsimd_size_t b_length, // - simsimd_distance_t* results) { + simsimd_distance_t *results) { // A single SVE lane is 128 bits wide, so one lane fits 8 values. simsimd_size_t const register_size = svcnth(); @@ -1402,8 +1396,8 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( // simsimd_u64_t b_step = svcntp_b16(b_progress, b_mask); // Compare `a_vec` with each lane of `b_vec` - svbfloat16_t a_weights_vec = svld1_bf16(a_progress, (__bf16 const*)a_weights + a_idx); - svbfloat16_t b_weights_vec = svld1_bf16(b_progress, (__bf16 const*)b_weights + b_idx); + svbfloat16_t a_weights_vec = svld1_bf16(a_progress, (__bf16 const *)a_weights + a_idx); + svbfloat16_t b_weights_vec = svld1_bf16(b_progress, (__bf16 const *)b_weights + b_idx); for (simsimd_size_t i = 0; i < lanes_count; i++) { svbool_t equal_mask = svmatch_u16(a_progress, a_vec, b_vec); //! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type. 
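For reviewing the sparse `spdot` kernels above: they compute a dot product of the weights gathered at the matching positions of two sorted index arrays, so a scalar reference makes a convenient test oracle. A minimal sketch assuming plain `f32` weights, since portable `bf16` arithmetic depends on compiler support:

#include <stddef.h>
#include <stdint.h>

static double spdot_weights_serial(uint16_t const *a, uint16_t const *b,
                                   float const *a_weights, float const *b_weights,
                                   size_t a_length, size_t b_length) {
    double product = 0;
    size_t i = 0, j = 0;
    // Classic two-pointer merge over the sorted index arrays:
    while (i < a_length && j < b_length) {
        if (a[i] < b[j]) i++;
        else if (a[i] > b[j]) j++;
        else product += (double)a_weights[i] * (double)b_weights[j], i++, j++;
    }
    return product;
}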
@@ -1427,7 +1421,7 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( // #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16 -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM #ifdef __cplusplus } diff --git a/include/simsimd/spatial.h b/include/simsimd/spatial.h index e669de4b..e03e5a78 100644 --- a/include/simsimd/spatial.h +++ b/include/simsimd/spatial.h @@ -132,6 +132,9 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf SIMSIMD_PUBLIC void simsimd_l2_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2_f64_haswell(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_f64_haswell(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f64_haswell(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* d); /* SIMD-powered backends for AVX512 CPUs of Skylake generation and newer, using 32-bit arithmetic over 512-bit words. * Skylake was launched in 2015, and discontinued in 2019. Skylake had support for F, CD, VL, DQ, and BW extensions, @@ -170,47 +173,45 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16 SIMSIMD_PUBLIC void simsimd_cos_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); // clang-format on -#define SIMSIMD_MAKE_L2SQ(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_l2sq_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t d2 = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ - d2 += (ai - bi) * (ai - bi); \ - } \ - *result = d2; \ - } - -#define SIMSIMD_MAKE_L2(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_l2_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_l2sq_##input_type##_##name(a, b, n, result); \ - *result = SIMSIMD_SQRT(*result); \ - } - -#define SIMSIMD_MAKE_COS(name, input_type, accumulator_type, load_and_convert) \ - SIMSIMD_PUBLIC void simsimd_cos_##input_type##_##name(simsimd_##input_type##_t const* a, \ - simsimd_##input_type##_t const* b, simsimd_size_t n, \ - simsimd_distance_t* result) { \ - simsimd_##accumulator_type##_t ab = 0, a2 = 0, b2 = 0; \ - for (simsimd_size_t i = 0; i != n; ++i) { \ - simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ - simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ - ab += ai * bi; \ - a2 += ai * ai; \ - b2 += bi * bi; \ - } \ - if (a2 == 0 && b2 == 0) { \ - *result = 0; \ - } else if (ab == 0) { \ - *result = 1; \ - } else { \ - simsimd_distance_t unclipped_result = 1 - ab * SIMSIMD_RSQRT(a2) * SIMSIMD_RSQRT(b2); \ - *result = unclipped_result > 0 ? 
unclipped_result : 0; \ - } \ +#define SIMSIMD_MAKE_L2SQ(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_l2sq_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t d2 = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + d2 += (ai - bi) * (ai - bi); \ + } \ + *result = d2; \ + } + +#define SIMSIMD_MAKE_L2(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_l2_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_l2sq_##input_type##_##name(a, b, n, result); \ + *result = SIMSIMD_SQRT(*result); \ + } + +#define SIMSIMD_MAKE_COS(name, input_type, accumulator_type, load_and_convert) \ + SIMSIMD_PUBLIC void simsimd_cos_##input_type##_##name(simsimd_##input_type##_t const *a, \ + simsimd_##input_type##_t const *b, simsimd_size_t n, \ + simsimd_distance_t *result) { \ + simsimd_##accumulator_type##_t ab = 0, a2 = 0, b2 = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = load_and_convert(a + i); \ + simsimd_##accumulator_type##_t bi = load_and_convert(b + i); \ + ab += ai * bi; \ + a2 += ai * ai; \ + b2 += bi * bi; \ + } \ + if (a2 == 0 && b2 == 0) { *result = 0; } \ + else if (ab == 0) { *result = 1; } \ + else { \ + simsimd_distance_t unclipped_result = 1 - ab * SIMSIMD_RSQRT(a2) * SIMSIMD_RSQRT(b2); \ + *result = unclipped_result > 0 ? unclipped_result : 0; \ + } \ } SIMSIMD_MAKE_COS(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_cos_f64_serial @@ -249,7 +250,7 @@ SIMSIMD_MAKE_COS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_cos_bf16_ SIMSIMD_MAKE_L2SQ(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_l2sq_bf16_accurate SIMSIMD_MAKE_L2(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_l2_bf16_accurate -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") @@ -263,10 +264,8 @@ SIMSIMD_INTERNAL simsimd_f64_t _simsimd_sqrt_f64_neon(simsimd_f64_t x) { } SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f32_neon(simsimd_f32_t ab, simsimd_f32_t a2, simsimd_f32_t b2) { - if (a2 == 0 && b2 == 0) - return 0; - if (ab == 0) - return 1; + if (a2 == 0 && b2 == 0) return 0; + if (ab == 0) return 1; simsimd_f32_t squares_arr[2] = {a2, b2}; float32x2_t squares = vld1_f32(squares_arr); // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation. @@ -286,10 +285,8 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f32_neon(simsimd_f32_ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f64_neon(simsimd_f64_t ab, simsimd_f64_t a2, simsimd_f64_t b2) { - if (a2 == 0 && b2 == 0) - return 0; - if (ab == 0) - return 1; + if (a2 == 0 && b2 == 0) return 0; + if (ab == 0) return 1; simsimd_f64_t squares_arr[2] = {a2, b2}; float64x2_t squares = vld1q_f64(squares_arr); @@ -308,13 +305,13 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f64_neon(simsimd_f64_ return result > 0 ? 
result : 0; } -SIMSIMD_PUBLIC void simsimd_l2_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f32_neon(a, b, n, result); *result = _simsimd_sqrt_f64_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t sum_vec = vdupq_n_f32(0); simsimd_size_t i = 0; for (; i + 4 <= n; i += 4) { @@ -331,8 +328,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f32_neon(simsimd_f32_t const* a, simsimd_f32_t *result = sum; } -SIMSIMD_PUBLIC void simsimd_cos_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f32_neon(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0), a2_vec = vdupq_n_f32(0), b2_vec = vdupq_n_f32(0); simsimd_size_t i = 0; for (; i + 4 <= n; i += 4) { @@ -351,13 +348,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_neon(simsimd_f32_t const* a, simsimd_f32_t c *result = _simsimd_cos_normalize_f64_neon(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_f64_neon(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f64_neon(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f64_neon(a, b, n, result); *result = _simsimd_sqrt_f64_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f64_neon(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f64_neon(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float64x2_t sum_vec = vdupq_n_f64(0); simsimd_size_t i = 0; for (; i + 2 <= n; i += 2) { @@ -374,8 +371,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f64_neon(simsimd_f64_t const* a, simsimd_f64_t *result = sum; } -SIMSIMD_PUBLIC void simsimd_cos_f64_neon(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f64_neon(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float64x2_t ab_vec = vdupq_n_f64(0), a2_vec = vdupq_n_f64(0), b2_vec = vdupq_n_f64(0); simsimd_size_t i = 0; for (; i + 2 <= n; i += 2) { @@ -403,13 +400,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_neon(simsimd_f64_t const* a, simsimd_f64_t c #pragma GCC target("arch=armv8.2-a+simd+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f16_neon(a, b, n, result); *result = _simsimd_sqrt_f32_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f16_neon(simsimd_f16_t const *a, simsimd_f16_t 
const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t a_vec, b_vec; float32x4_t sum_vec = vdupq_n_f32(0); @@ -418,21 +415,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16_neon(simsimd_f16_t const* a, simsimd_f16_t a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); n = 0; - } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b)); + } + else { + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); n -= 4, a += 4, b += 4; } float32x4_t diff_vec = vsubq_f32(a_vec, b_vec); sum_vec = vfmaq_f32(sum_vec, diff_vec, diff_vec); - if (n) - goto simsimd_l2sq_f16_neon_cycle; + if (n) goto simsimd_l2sq_f16_neon_cycle; *result = vaddvq_f32(sum_vec); } -SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t ab_vec = vdupq_n_f32(0), a2_vec = vdupq_n_f32(0), b2_vec = vdupq_n_f32(0); float32x4_t a_vec, b_vec; @@ -441,16 +438,16 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const* a, simsimd_f16_t c a_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(a, n)); b_vec = vcvt_f32_f16(_simsimd_partial_load_f16x4_neon(b, n)); n = 0; - } else { - a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a)); - b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b)); + } + else { + a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)a)); + b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const *)b)); n -= 4, a += 4, b += 4; } ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); a2_vec = vfmaq_f32(a2_vec, a_vec, a_vec); b2_vec = vfmaq_f32(b2_vec, b_vec, b_vec); - if (n) - goto simsimd_cos_f16_neon_cycle; + if (n) goto simsimd_cos_f16_neon_cycle; simsimd_f32_t ab = vaddvq_f32(ab_vec), a2 = vaddvq_f32(a2_vec), b2 = vaddvq_f32(b2_vec); *result = _simsimd_cos_normalize_f32_neon(ab, a2, b2); @@ -465,8 +462,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const* a, simsimd_f16_t c #pragma GCC target("arch=armv8.6-a+simd+bf16") #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { // Similar to `simsimd_cos_i8_neon`, we can use the `BFMMLA` instruction through // the `vbfmmlaq_f32` intrinsic to compute matrix products and later drop 1/4 of values. 
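As background for the `bf16` hunks that follow: a `bf16` value is just the top 16 bits of an IEEE-754 `f32`, so a portable scalar decode is a shift plus a bit-cast. A minimal sketch, independent of the NEON `BFDOT`/`BFMMLA` paths:

#include <stdint.h>
#include <string.h>

static float bf16_to_f32(uint16_t raw) {
    uint32_t bits = (uint32_t)raw << 16; // the truncated low mantissa bits come back as zeros
    float value;
    memcpy(&value, &bits, sizeof(value)); // bit-cast without violating strict aliasing
    return value;
}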
@@ -517,29 +514,29 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_ a_vec = _simsimd_partial_load_bf16x8_neon(a, n); b_vec = _simsimd_partial_load_bf16x8_neon(b, n); n = 0; - } else { - a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a); - b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b); + } + else { + a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)a); + b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)b); n -= 8, a += 8, b += 8; } ab_vec = vbfdotq_f32(ab_vec, a_vec, b_vec); a2_vec = vbfdotq_f32(a2_vec, a_vec, a_vec); b2_vec = vbfdotq_f32(b2_vec, b_vec, b_vec); - if (n) - goto simsimd_cos_bf16_neon_cycle; + if (n) goto simsimd_cos_bf16_neon_cycle; // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON simsimd_f32_t ab = vaddvq_f32(ab_vec), a2 = vaddvq_f32(a2_vec), b2 = vaddvq_f32(b2_vec); *result = _simsimd_cos_normalize_f32_neon(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_bf16_neon(a, b, n, result); *result = _simsimd_sqrt_f64_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { float32x4_t diff_high_vec, diff_low_vec; float32x4_t sum_high_vec = vdupq_n_f32(0), sum_low_vec = vdupq_n_f32(0); @@ -550,17 +547,17 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16 diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_vec)), vcvt_f32_bf16(vget_high_bf16(b_vec))); diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_vec)), vcvt_f32_bf16(vget_low_bf16(b_vec))); n = 0; - } else { - bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a); - bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b); + } + else { + bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)a); + bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)b); diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_vec)), vcvt_f32_bf16(vget_high_bf16(b_vec))); diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_vec)), vcvt_f32_bf16(vget_low_bf16(b_vec))); n -= 8, a += 8, b += 8; } sum_high_vec = vfmaq_f32(sum_high_vec, diff_high_vec, diff_high_vec); sum_low_vec = vfmaq_f32(sum_low_vec, diff_low_vec, diff_low_vec); - if (n) - goto simsimd_l2sq_bf16_neon_cycle; + if (n) goto simsimd_l2sq_bf16_neon_cycle; *result = vaddvq_f32(vaddq_f32(sum_high_vec, sum_low_vec)); } @@ -574,13 +571,13 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16 #pragma GCC target("arch=armv8.2-a+dotprod+i8mm") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod+i8mm"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_i8_neon(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_i8_neon(a, b, n, result); *result = _simsimd_sqrt_f32_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t 
const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { // The naive approach is to upcast 8-bit signed integers into 16-bit signed integers // for subtraction, then multiply within 16-bit integers and accumulate the results @@ -606,8 +603,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t con *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; @@ -725,13 +722,13 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t cons *result = _simsimd_cos_normalize_f32_neon(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_u8_neon(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_u8_neon(a, b, n, result); *result = _simsimd_sqrt_f32_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_u8_neon(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { uint32x4_t d2_vec = vdupq_n_u32(0); simsimd_size_t i = 0; for (; i + 16 <= n; i += 16) { @@ -748,8 +745,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8_neon(simsimd_u8_t const* a, simsimd_u8_t con *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_u8_neon(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; uint32x4_t ab_vec = vdupq_n_u32(0); @@ -784,13 +781,13 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_neon(simsimd_u8_t const* a, simsimd_u8_t cons #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f32_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f32_sve(a, b, n, result); *result = _simsimd_sqrt_f64_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f32_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat32_t d2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); do { @@ -805,8 +802,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f32_sve(simsimd_f32_t const* a, simsimd_f32_t c *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f32_sve(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat32_t ab_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); 
svfloat32_t a2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); @@ -827,13 +824,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_sve(simsimd_f32_t const* a, simsimd_f32_t co *result = _simsimd_cos_normalize_f64_neon(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f64_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f64_sve(a, b, n, result); *result = _simsimd_sqrt_f64_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f64_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat64_t d2_vec = svdupq_n_f64(0.0, 0.0); do { @@ -848,8 +845,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f64_sve(simsimd_f64_t const* a, simsimd_f64_t c *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat64_t ab_vec = svdupq_n_f64(0.0, 0.0); svfloat64_t a2_vec = svdupq_n_f64(0.0, 0.0); @@ -879,17 +876,17 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const* a, simsimd_f64_t co #pragma GCC target("arch=armv8.2-a+sve+fp16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f16_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f16_sve(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f16_sve(a, b, n, result); *result = _simsimd_sqrt_f32_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f16_sve(simsimd_f16_t const *a_enum, simsimd_f16_t const *b_enum, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat16_t d2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); - simsimd_f16_for_arm_simd_t const* a = (simsimd_f16_for_arm_simd_t const*)(a_enum); - simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + simsimd_f16_for_arm_simd_t const *a = (simsimd_f16_for_arm_simd_t const *)(a_enum); + simsimd_f16_for_arm_simd_t const *b = (simsimd_f16_for_arm_simd_t const *)(b_enum); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); svfloat16_t a_vec = svld1_f16(pg_vec, a + i); @@ -902,14 +899,14 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16_sve(simsimd_f16_t const* a_enum, simsimd_f1 *result = d2_f16; } -SIMSIMD_PUBLIC void simsimd_cos_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f16_sve(simsimd_f16_t const *a_enum, simsimd_f16_t const *b_enum, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat16_t ab_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); svfloat16_t a2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); svfloat16_t b2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); - simsimd_f16_for_arm_simd_t const* a = 
(simsimd_f16_for_arm_simd_t const*)(a_enum); - simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + simsimd_f16_for_arm_simd_t const *a = (simsimd_f16_for_arm_simd_t const *)(a_enum); + simsimd_f16_for_arm_simd_t const *b = (simsimd_f16_for_arm_simd_t const *)(b_enum); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); svfloat16_t a_vec = svld1_f16(pg_vec, a + i); @@ -935,18 +932,18 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16 #pragma GCC target("arch=armv8.2-a+sve+bf16") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+bf16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_bf16_sve(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_bf16_sve(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_bf16_sve(a, b, n, result); *result = _simsimd_sqrt_f32_neon(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_bf16_sve(simsimd_bf16_t const* a_enum, simsimd_bf16_t const* b_enum, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_sve(simsimd_bf16_t const *a_enum, simsimd_bf16_t const *b_enum, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat32_t d2_low_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); svfloat32_t d2_high_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); - simsimd_u16_t const* a = (simsimd_u16_t const*)(a_enum); - simsimd_u16_t const* b = (simsimd_u16_t const*)(b_enum); + simsimd_u16_t const *a = (simsimd_u16_t const *)(a_enum); + simsimd_u16_t const *b = (simsimd_u16_t const *)(b_enum); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); svuint16_t a_vec = svld1_u16(pg_vec, a + i); @@ -970,14 +967,14 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_sve(simsimd_bf16_t const* a_enum, simsimd_ *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_bf16_sve(simsimd_bf16_t const* a_enum, simsimd_bf16_t const* b_enum, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_bf16_sve(simsimd_bf16_t const *a_enum, simsimd_bf16_t const *b_enum, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_size_t i = 0; svfloat32_t ab_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); svfloat32_t a2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); svfloat32_t b2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); - simsimd_bf16_for_arm_simd_t const* a = (simsimd_bf16_for_arm_simd_t const*)(a_enum); - simsimd_bf16_for_arm_simd_t const* b = (simsimd_bf16_for_arm_simd_t const*)(b_enum); + simsimd_bf16_for_arm_simd_t const *a = (simsimd_bf16_for_arm_simd_t const *)(a_enum); + simsimd_bf16_for_arm_simd_t const *b = (simsimd_bf16_for_arm_simd_t const *)(b_enum); do { svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); svbfloat16_t a_vec = svld1_bf16(pg_vec, a + i); @@ -997,9 +994,9 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_sve(simsimd_bf16_t const* a_enum, simsimd_b #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SVE_BF16 -#endif // SIMSIMD_TARGET_ARM +#endif // _SIMSIMD_TARGET_ARM -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_HASWELL #pragma GCC push_options #pragma GCC target("avx2") @@ -1016,8 +1013,7 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f64_haswell(simsimd_f simsimd_f64_t b2) { // If both vectors have magnitude 0, the distance is 0. 
- if (a2 == 0 && b2 == 0) - return 0; + if (a2 == 0 && b2 == 0) return 0; // If any one of the vectors is 0, the square root of the product is 0, // the division is ill-formed, and the result is 1. else if (ab == 0) @@ -1045,8 +1041,7 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f32_haswell(simsimd_f simsimd_f32_t b2) { // If both vectors have magnitude 0, the distance is 0. - if (a2 == 0.0f && b2 == 0.0f) - return 0.0f; + if (a2 == 0.0f && b2 == 0.0f) return 0.0f; // If any one of the vectors is 0, the square root of the product is 0, // the division is ill-formed, and the result is 1. else if (ab == 0.0f) @@ -1077,21 +1072,21 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f32_haswell(simsimd_f #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_HASWELL -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 -#if SIMSIMD_TARGET_X86 +#if _SIMSIMD_TARGET_X86 #if SIMSIMD_TARGET_HASWELL #pragma GCC push_options #pragma GCC target("avx2", "f16c", "fma") #pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f16_haswell(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 d2_vec = _mm256_setzero_ps(); @@ -1100,21 +1095,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const* a, simsimd_f16 a_vec = _simsimd_partial_load_f16x8_haswell(a, n); b_vec = _simsimd_partial_load_f16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); - if (n) - goto simsimd_l2sq_f16_haswell_cycle; + if (n) goto simsimd_l2sq_f16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(d2_vec); } -SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(), a2_vec = _mm256_setzero_ps(), b2_vec = _mm256_setzero_ps(); @@ -1123,16 +1118,16 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_ a_vec = _simsimd_partial_load_f16x8_haswell(a, n); b_vec = _simsimd_partial_load_f16x8_haswell(b, n); n = 0; - } else { - a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _mm256_cvtph_ps(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } ab_vec = _mm256_fmadd_ps(a_vec, b_vec,
ab_vec); a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); - if (n) - goto simsimd_cos_f16_haswell_cycle; + if (n) goto simsimd_cos_f16_haswell_cycle; simsimd_f32_t ab = _simsimd_reduce_f32x8_haswell(ab_vec); simsimd_f32_t a2 = _simsimd_reduce_f32x8_haswell(a2_vec); @@ -1140,13 +1135,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_ *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_bf16_haswell(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 d2_vec = _mm256_setzero_ps(); @@ -1155,21 +1150,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_haswell(simsimd_bf16_t const* a, simsimd_b a_vec = _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(a, n)); b_vec = _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b, n)); n = 0; - } else { - a_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); - if (n) - goto simsimd_l2sq_bf16_haswell_cycle; + if (n) goto simsimd_l2sq_bf16_haswell_cycle; *result = _simsimd_reduce_f32x8_haswell(d2_vec); } -SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 a_vec, b_vec; __m256 ab_vec = _mm256_setzero_ps(), a2_vec = _mm256_setzero_ps(), b2_vec = _mm256_setzero_ps(); @@ -1178,16 +1173,16 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf a_vec = _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(a, n)); b_vec = _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b, n)); n = 0; - } else { - a_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)a)); - b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const*)b)); + } + else { + a_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)a)); + b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)b)); n -= 8, a += 8, b += 8; } ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); - if (n) - goto simsimd_cos_bf16_haswell_cycle; + if (n) goto simsimd_cos_bf16_haswell_cycle; simsimd_f32_t ab = _simsimd_reduce_f32x8_haswell(ab_vec); simsimd_f32_t a2 = _simsimd_reduce_f32x8_haswell(a2_vec); @@ -1195,21 +1190,21 @@ 
SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_i8_haswell(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i d2_i32_low_vec = _mm256_setzero_si256(); __m256i d2_i32_high_vec = _mm256_setzero_si256(); simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Sign extend `i8` to `i16` __m256i a_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(a_i8_vec)); @@ -1240,8 +1235,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t *result = (simsimd_f64_t)d2; } -SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1264,8 +1259,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c // This can easily lead to noticeable numerical errors in the final result. 
simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Unpack `int8` to `int16` __m256i a_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_i8_vec, 0)); @@ -1296,13 +1291,13 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_u8_haswell(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i d2_i32_low_vec = _mm256_setzero_si256(); __m256i d2_i32_high_vec = _mm256_setzero_si256(); @@ -1310,8 +1305,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Subtracting unsigned vectors in AVX2 is done by saturating subtraction: __m256i d_u8_vec = _mm256_or_si256(_mm256_subs_epu8(a_u8_vec, b_u8_vec), _mm256_subs_epu8(b_u8_vec, a_u8_vec)); @@ -1338,8 +1333,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t *result = (simsimd_f64_t)d2; } -SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_low_vec = _mm256_setzero_si256(); __m256i ab_i32_high_vec = _mm256_setzero_si256(); @@ -1363,8 +1358,8 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t c // This can easily lead to noticeable numerical errors in the final result. simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const *)(b + i)); // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking // instructions instead of extracts, as they are much faster and more efficient.
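To make the unpacking remark above concrete: interleaving a `u8` vector with zeros is exactly zero-extension to 16-bit lanes, which is why it works here but not in the signed kernels, where the sign-extending `cvtepi8` forms are required. A minimal AVX2 sketch; the helper name is illustrative, not part of the library:

#include <immintrin.h>

static inline void u8_to_u16_halves(__m256i u8_vec, __m256i *low, __m256i *high) {
    __m256i zeros = _mm256_setzero_si256();
    // Interleaving with zeros puts each byte into the low half of a 16-bit lane.
    // Note that unpacking works within 128-bit halves, so the two outputs are
    // lane-crossed relative to memory order; that is harmless for the
    // order-independent dot-product reductions used here.
    *low = _mm256_unpacklo_epi8(u8_vec, zeros);
    *high = _mm256_unpackhi_epi8(u8_vec, zeros);
}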
@@ -1396,13 +1391,13 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t c *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f32_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f32_haswell(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f32_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 d2_vec = _mm256_setzero_ps(); simsimd_size_t i = 0; @@ -1422,8 +1417,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f32_haswell(simsimd_f32_t const* a, simsimd_f32 *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f32_haswell(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256 ab_vec = _mm256_setzero_ps(); __m256 a2_vec = _mm256_setzero_ps(); @@ -1447,6 +1442,57 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_haswell(simsimd_f32_t const* a, simsimd_f32_ *result = _simsimd_cos_normalize_f64_haswell(ab, a2, b2); } +SIMSIMD_PUBLIC void simsimd_l2_f64_haswell(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { + simsimd_l2sq_f64_haswell(a, b, n, result); + *result = _simsimd_sqrt_f64_haswell(*result); +} +SIMSIMD_PUBLIC void simsimd_l2sq_f64_haswell(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { + + __m256d d2_vec = _mm256_setzero_pd(); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256d a_vec = _mm256_loadu_pd(a + i); + __m256d b_vec = _mm256_loadu_pd(b + i); + __m256d d_vec = _mm256_sub_pd(a_vec, b_vec); + d2_vec = _mm256_fmadd_pd(d_vec, d_vec, d2_vec); + } + + simsimd_f64_t d2 = _simsimd_reduce_f64x4_haswell(d2_vec); + for (; i < n; ++i) { + simsimd_f64_t d = a[i] - b[i]; + d2 += d * d; + } + + *result = d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_f64_haswell(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { + + __m256d ab_vec = _mm256_setzero_pd(); + __m256d a2_vec = _mm256_setzero_pd(); + __m256d b2_vec = _mm256_setzero_pd(); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256d a_vec = _mm256_loadu_pd(a + i); + __m256d b_vec = _mm256_loadu_pd(b + i); + ab_vec = _mm256_fmadd_pd(a_vec, b_vec, ab_vec); + a2_vec = _mm256_fmadd_pd(a_vec, a_vec, a2_vec); + b2_vec = _mm256_fmadd_pd(b_vec, b_vec, b2_vec); + } + + simsimd_f64_t ab = _simsimd_reduce_f64x4_haswell(ab_vec); + simsimd_f64_t a2 = _simsimd_reduce_f64x4_haswell(a2_vec); + simsimd_f64_t b2 = _simsimd_reduce_f64x4_haswell(b2_vec); + for (; i < n; ++i) { + simsimd_f64_t ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + *result = _simsimd_cos_normalize_f64_haswell(ab, a2, b2); +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_HASWELL @@ -1456,13 +1502,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_haswell(simsimd_f32_t const* a, simsimd_f32_ #pragma GCC target("avx2", "avx512f", "avx512bw", "avx512vl", "bmi2") #pragma clang attribute 
push(__attribute__((target("avx2,avx512f,avx512bw,avx512vl,bmi2"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f32_skylake(a, b, n, result); *result = _simsimd_sqrt_f64_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 d2_vec = _mm512_setzero(); __m512 a_vec, b_vec; @@ -1472,15 +1518,15 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f32_skylake(simsimd_f32_t const* a, simsimd_f32 a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; } __m512 d_vec = _mm512_sub_ps(a_vec, b_vec); d2_vec = _mm512_fmadd_ps(d_vec, d_vec, d2_vec); - if (n) - goto simsimd_l2sq_f32_skylake_cycle; + if (n) goto simsimd_l2sq_f32_skylake_cycle; *result = _simsimd_reduce_f32x16_skylake(d2_vec); } @@ -1489,8 +1535,7 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f64_skylake(simsimd_f simsimd_f64_t b2) { // If both vectors have magnitude 0, the distance is 0. - if (a2 == 0 && b2 == 0) - return 0; + if (a2 == 0 && b2 == 0) return 0; // If any one of the vectors is 0, the square root of the product is 0, // the division is ill-formed, and the result is 1. else if (ab == 0) @@ -1529,8 +1574,8 @@ SIMSIMD_INTERNAL simsimd_distance_t _simsimd_cos_normalize_f64_skylake(simsimd_f return result > 0 ?
result : 0; } -SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 ab_vec = _mm512_setzero(); __m512 a2_vec = _mm512_setzero(); __m512 b2_vec = _mm512_setzero(); @@ -1542,7 +1587,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_ a_vec = _mm512_maskz_loadu_ps(mask, a); b_vec = _mm512_maskz_loadu_ps(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_ps(a); b_vec = _mm512_loadu_ps(b); a += 16, b += 16, n -= 16; @@ -1550,8 +1596,7 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_ ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); a2_vec = _mm512_fmadd_ps(a_vec, a_vec, a2_vec); b2_vec = _mm512_fmadd_ps(b_vec, b_vec, b2_vec); - if (n) - goto simsimd_cos_f32_skylake_cycle; + if (n) goto simsimd_cos_f32_skylake_cycle; simsimd_f64_t ab = _simsimd_reduce_f32x16_skylake(ab_vec); simsimd_f64_t a2 = _simsimd_reduce_f32x16_skylake(a2_vec); @@ -1559,13 +1604,13 @@ SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_ *result = _simsimd_cos_normalize_f64_skylake(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f64_skylake(a, b, n, result); *result = _simsimd_sqrt_f64_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512d d2_vec = _mm512_setzero_pd(); __m512d a_vec, b_vec; @@ -1575,21 +1620,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f64_skylake(simsimd_f64_t const* a, simsimd_f64 a_vec = _mm512_maskz_loadu_pd(mask, a); b_vec = _mm512_maskz_loadu_pd(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_pd(a); b_vec = _mm512_loadu_pd(b); a += 8, b += 8, n -= 8; } __m512d d_vec = _mm512_sub_pd(a_vec, b_vec); d2_vec = _mm512_fmadd_pd(d_vec, d_vec, d2_vec); - if (n) - goto simsimd_l2sq_f64_skylake_cycle; + if (n) goto simsimd_l2sq_f64_skylake_cycle; *result = _mm512_reduce_add_pd(d2_vec); } -SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512d ab_vec = _mm512_setzero_pd(); __m512d a2_vec = _mm512_setzero_pd(); __m512d b2_vec = _mm512_setzero_pd(); @@ -1601,7 +1646,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_ a_vec = _mm512_maskz_loadu_pd(mask, a); b_vec = _mm512_maskz_loadu_pd(mask, b); n = 0; - } else { + } + else { a_vec = _mm512_loadu_pd(a); b_vec = _mm512_loadu_pd(b); a += 8, b += 8, n -= 8; @@ -1609,8 +1655,7 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_ ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec); a2_vec = _mm512_fmadd_pd(a_vec, a_vec, a2_vec); b2_vec = _mm512_fmadd_pd(b_vec, b_vec, b2_vec); - if (n) - goto 
simsimd_cos_f64_skylake_cycle; + if (n) goto simsimd_cos_f64_skylake_cycle; simsimd_f64_t ab = _mm512_reduce_add_pd(ab_vec); simsimd_f64_t a2 = _mm512_reduce_add_pd(a2_vec); @@ -1625,7 +1670,7 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_ #if SIMSIMD_TARGET_GENOA #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), \ apply_to = function) SIMSIMD_INTERNAL __m512i _simsimd_substract_bf16x32_genoa(__m512i a_i16, __m512i b_i16) { @@ -1680,13 +1725,13 @@ SIMSIMD_INTERNAL __m512i _simsimd_substract_bf16x32_genoa(__m512i a_i16, __m512i return d.ivec; } -SIMSIMD_PUBLIC void simsimd_l2_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_bf16_genoa(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 d2_vec = _mm512_setzero_ps(); __m512i a_i16_vec, b_i16_vec, d_i16_vec; @@ -1696,21 +1741,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf1 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; } d_i16_vec = _simsimd_substract_bf16x32_genoa(a_i16_vec, b_i16_vec); d2_vec = _mm512_dpbf16_ps(d2_vec, (__m512bh)(d_i16_vec), (__m512bh)(d_i16_vec)); - if (n) - goto simsimd_l2sq_bf16_genoa_cycle; + if (n) goto simsimd_l2sq_bf16_genoa_cycle; *result = _simsimd_reduce_f32x16_skylake(d2_vec); } -SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512 ab_vec = _mm512_setzero_ps(); __m512 a2_vec = _mm512_setzero_ps(); __m512 b2_vec = _mm512_setzero_ps(); @@ -1722,7 +1767,8 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1730,8 +1776,7 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16 ab_vec = _mm512_dpbf16_ps(ab_vec, (__m512bh)(a_i16_vec), (__m512bh)(b_i16_vec)); a2_vec = _mm512_dpbf16_ps(a2_vec, (__m512bh)(a_i16_vec), (__m512bh)(a_i16_vec)); b2_vec = _mm512_dpbf16_ps(b2_vec, (__m512bh)(b_i16_vec), (__m512bh)(b_i16_vec)); - if (n) - goto simsimd_cos_bf16_genoa_cycle; + if (n) goto simsimd_cos_bf16_genoa_cycle; simsimd_f32_t ab = _simsimd_reduce_f32x16_skylake(ab_vec); simsimd_f32_t a2 = _simsimd_reduce_f32x16_skylake(a2_vec); @@ -1748,13 +1793,13 @@ SIMSIMD_PUBLIC void 
simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16 #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512fp16") #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512fp16"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_f16_sapphire(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512h d2_vec = _mm512_setzero_ph(); __m512i a_i16_vec, b_i16_vec; @@ -1764,21 +1809,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16_sapphire(simsimd_f16_t const* a, simsimd_f1 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; } __m512h d_vec = _mm512_sub_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec)); d2_vec = _mm512_fmadd_ph(d_vec, d_vec, d2_vec); - if (n) - goto simsimd_l2sq_f16_sapphire_cycle; + if (n) goto simsimd_l2sq_f16_sapphire_cycle; *result = _mm512_reduce_add_ph(d2_vec); } -SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512h ab_vec = _mm512_setzero_ph(); __m512h a2_vec = _mm512_setzero_ph(); __m512h b2_vec = _mm512_setzero_ph(); @@ -1790,7 +1835,8 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16 a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); n = 0; - } else { + } + else { a_i16_vec = _mm512_loadu_epi16(a); b_i16_vec = _mm512_loadu_epi16(b); a += 32, b += 32, n -= 32; @@ -1798,8 +1844,7 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16 ab_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec), ab_vec); a2_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(a_i16_vec), a2_vec); b2_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(b_i16_vec), _mm512_castsi512_ph(b_i16_vec), b2_vec); - if (n) - goto simsimd_cos_f16_sapphire_cycle; + if (n) goto simsimd_cos_f16_sapphire_cycle; simsimd_f32_t ab = _mm512_reduce_add_ph(ab_vec); simsimd_f32_t a2 = _mm512_reduce_add_ph(a2_vec); @@ -1814,16 +1859,16 @@ SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16 #if SIMSIMD_TARGET_ICE #pragma GCC push_options #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vnni") -#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), \ +#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), \ apply_to = function) -SIMSIMD_PUBLIC void simsimd_l2_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void 
simsimd_l2_i8_ice(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_i8_ice(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i d2_i32_vec = _mm512_setzero_si512(); __m512i a_i16_vec, b_i16_vec, d_i16s_vec; @@ -1833,21 +1878,21 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t cons a_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); n = 0; - } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)b)); + } + else { + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)a)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)b)); a += 32, b += 32, n -= 32; } d_i16s_vec = _mm512_sub_epi16(a_i16_vec, b_i16_vec); d2_i32_vec = _mm512_dpwssd_epi32(d2_i32_vec, d_i16s_vec, d_i16s_vec); - if (n) - goto simsimd_l2sq_i8_ice_cycle; + if (n) goto simsimd_l2sq_i8_ice_cycle; *result = _mm512_reduce_add_epi32(d2_i32_vec); } -SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i ab_i32_vec = _mm512_setzero_si512(); __m512i a2_i32_vec = _mm512_setzero_si512(); @@ -1859,9 +1904,10 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const a_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); n = 0; - } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)b)); + } + else { + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)a)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const *)b)); a += 32, b += 32, n -= 32; } @@ -1907,21 +1953,20 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const ab_i32_vec = _mm512_add_epi32(ab_i32_vec, _mm512_madd_epi16(a_i16_vec, b_i16_vec)); a2_i32_vec = _mm512_add_epi32(a2_i32_vec, _mm512_madd_epi16(a_i16_vec, a_i16_vec)); b2_i32_vec = _mm512_add_epi32(b2_i32_vec, _mm512_madd_epi16(b_i16_vec, b_i16_vec)); - if (n) - goto simsimd_cos_i8_ice_cycle; + if (n) goto simsimd_cos_i8_ice_cycle; int ab = _mm512_reduce_add_epi32(ab_i32_vec); int a2 = _mm512_reduce_add_epi32(a2_i32_vec); int b2 = _mm512_reduce_add_epi32(b2_i32_vec); *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { simsimd_l2sq_u8_ice(a, b, n, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void 
simsimd_l2sq_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i d2_i32_low_vec = _mm512_setzero_si512(); __m512i d2_i32_high_vec = _mm512_setzero_si512(); __m512i const zeros_vec = _mm512_setzero_si512(); @@ -1934,7 +1979,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8_ice(simsimd_u8_t const* a, simsimd_u8_t cons a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); n = 0; - } else { + } + else { a_u8_vec = _mm512_loadu_si512(a); b_u8_vec = _mm512_loadu_si512(b); a += 64, b += 64, n -= 64; @@ -1948,14 +1994,13 @@ SIMSIMD_PUBLIC void simsimd_l2sq_u8_ice(simsimd_u8_t const* a, simsimd_u8_t cons // Multiply and accumulate at `int16` level, accumulate at `int32` level: d2_i32_low_vec = _mm512_dpwssd_epi32(d2_i32_low_vec, d_i16_low_vec, d_i16_low_vec); d2_i32_high_vec = _mm512_dpwssd_epi32(d2_i32_high_vec, d_i16_high_vec, d_i16_high_vec); - if (n) - goto simsimd_l2sq_u8_ice_cycle; + if (n) goto simsimd_l2sq_u8_ice_cycle; *result = _mm512_reduce_add_epi32(_mm512_add_epi32(d2_i32_low_vec, d2_i32_high_vec)); } -SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m512i ab_i32_low_vec = _mm512_setzero_si512(); __m512i ab_i32_high_vec = _mm512_setzero_si512(); @@ -1973,7 +2018,8 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); n = 0; - } else { + } + else { a_u8_vec = _mm512_loadu_si512(a); b_u8_vec = _mm512_loadu_si512(b); a += 64, b += 64, n -= 64; @@ -1993,8 +2039,7 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const a2_i32_high_vec = _mm512_dpwssds_epi32(a2_i32_high_vec, a_i16_high_vec, a_i16_high_vec); b2_i32_low_vec = _mm512_dpwssds_epi32(b2_i32_low_vec, b_i16_low_vec, b_i16_low_vec); b2_i32_high_vec = _mm512_dpwssds_epi32(b2_i32_high_vec, b_i16_high_vec, b_i16_high_vec); - if (n) - goto simsimd_cos_u8_ice_cycle; + if (n) goto simsimd_cos_u8_ice_cycle; int ab = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); int a2 = _mm512_reduce_add_epi32(_mm512_add_epi32(a2_i32_low_vec, a2_i32_high_vec)); @@ -2002,13 +2047,13 @@ SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } -SIMSIMD_PUBLIC void simsimd_l2_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2_i4x2_ice(simsimd_i4x2_t const *a, simsimd_i4x2_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { simsimd_l2sq_i4x2_ice(a, b, n_words, result); *result = _simsimd_sqrt_f32_haswell(*result); } -SIMSIMD_PUBLIC void simsimd_l2sq_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_l2sq_i4x2_ice(simsimd_i4x2_t const *a, simsimd_i4x2_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // While `int8_t` covers the range [-128, 127], `int4_t` covers only [-8, 7]. // The absolute difference between two 4-bit integers is at most 15 and it is always a `uint4_t` value! 
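/*
 * Editor's note, not part of the patch: a scalar sketch of the same `int4` arithmetic, to make
 * the range claim above concrete. `i4_sign_extend` and `l2sq_i4x2_serial` are hypothetical
 * helpers: a signed nibble maps to [-8, 7], so the largest absolute difference is
 * |7 - (-8)| = 15, which fits an unsigned 4-bit value, and its square, 225, still fits `uint8_t`.
 *
 * @code {.c}
 * #include <stddef.h>
 *
 * static int i4_sign_extend(unsigned char nibble) {
 *     nibble &= 0x0F;                                  // keep the low 4 bits
 *     return (int)nibble - ((nibble & 0x08) ? 16 : 0); // [-8, 7]
 * }
 *
 * // Serial reference for the squared Euclidean distance over nibble-packed vectors:
 * static unsigned l2sq_i4x2_serial(unsigned char const *a, unsigned char const *b, size_t n_words) {
 *     unsigned d2 = 0;
 *     for (size_t i = 0; i != n_words; ++i) {
 *         int d_low = i4_sign_extend(a[i]) - i4_sign_extend(b[i]);
 *         int d_high = i4_sign_extend(a[i] >> 4) - i4_sign_extend(b[i] >> 4);
 *         d2 += (unsigned)(d_low * d_low) + (unsigned)(d_high * d_high); // at most 450 per byte
 *     }
 *     return d2;
 * }
 * @endcode
 */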
@@ -2044,7 +2089,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_ a_i4x2_vec = _mm512_maskz_loadu_epi8(mask, a); b_i4x2_vec = _mm512_maskz_loadu_epi8(mask, b); n_words = 0; - } else { + } + else { a_i4x2_vec = _mm512_loadu_epi8(a); b_i4x2_vec = _mm512_loadu_epi8(b); a += 64, b += 64, n_words -= 64; @@ -2081,15 +2127,14 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_ _mm512_unpackhi_epi8(d2_u8_high_vec, _mm512_setzero_si512())); d2_u32_vec = _mm512_add_epi32(d2_u32_vec, _mm512_unpacklo_epi16(d2_u16_low_vec, _mm512_setzero_si512())); d2_u32_vec = _mm512_add_epi32(d2_u32_vec, _mm512_unpacklo_epi16(d2_u16_high_vec, _mm512_setzero_si512())); - if (n_words) - goto simsimd_l2sq_i4x2_ice_cycle; + if (n_words) goto simsimd_l2sq_i4x2_ice_cycle; // Finally, we can reduce the 16-bit integers to 32-bit integers and sum them up. int d2 = _mm512_reduce_add_epi32(d2_u32_vec); *result = d2; } -SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t const* b, simsimd_size_t n_words, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const *a, simsimd_i4x2_t const *b, simsimd_size_t n_words, + simsimd_distance_t *result) { // We need to compose a lookup table for all the scalar products of 4-bit integers. // While `int8_t` covers the range [-128, 127], `int4_t` covers only [-8, 7]. @@ -2158,7 +2203,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t a_i4x2_vec = _mm512_maskz_loadu_epi8(mask, a); b_i4x2_vec = _mm512_maskz_loadu_epi8(mask, b); n_words = 0; - } else { + } + else { a_i4x2_vec = _mm512_loadu_epi8(a); b_i4x2_vec = _mm512_loadu_epi8(b); a += 64, b += 64, n_words -= 64; @@ -2217,15 +2263,13 @@ SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t ab_i32_high_vec, // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_i8_high_vec, 1)), // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_i8_high_vec, 1))); - if (n_words) - goto simsimd_cos_i4x2_ice_cycle; + if (n_words) goto simsimd_cos_i4x2_ice_cycle; int ab = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); unsigned short a2_u16[32], b2_u16[32]; _mm512_storeu_si512(a2_u16, _mm512_add_epi16(a2_u16_low_vec, a2_u16_high_vec)); _mm512_storeu_si512(b2_u16, _mm512_add_epi16(b2_u16_low_vec, b2_u16_high_vec)); unsigned int a2 = 0, b2 = 0; - for (int i = 0; i < 32; ++i) - a2 += a2_u16[i], b2 += b2_u16[i]; + for (int i = 0; i < 32; ++i) a2 += a2_u16[i], b2 += b2_u16[i]; *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } @@ -2238,8 +2282,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t #pragma GCC target("avx2", "bmi2", "avx2vnni") #pragma clang attribute push(__attribute__((target("avx2,bmi2,avx2vnni"))), apply_to = function) -SIMSIMD_PUBLIC void simsimd_cos_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, - simsimd_distance_t* result) { +SIMSIMD_PUBLIC void simsimd_cos_i8_sierra(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, + simsimd_distance_t *result) { __m256i ab_i32_vec = _mm256_setzero_si256(); __m256i a2_i32_vec = _mm256_setzero_si256(); @@ -2247,8 +2291,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t co simsimd_size_t i = 0; for (; i + 32 <= n; i += 32) { - __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); - __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const *)(a + i)); + __m256i b_i8_vec =
_mm256_lddqu_si256((__m256i const *)(b + i)); ab_i32_vec = _mm256_dpbssds_epi32(ab_i32_vec, a_i8_vec, b_i8_vec); a2_i32_vec = _mm256_dpbssds_epi32(a2_i32_vec, a_i8_vec, a_i8_vec); b2_i32_vec = _mm256_dpbssds_epi32(b2_i32_vec, b_i8_vec, b_i8_vec); @@ -2271,7 +2315,7 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_sierra(simsimd_i8_t const* a, simsimd_i8_t co #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SIERRA -#endif // SIMSIMD_TARGET_X86 +#endif // _SIMSIMD_TARGET_X86 #ifdef __cplusplus } diff --git a/include/simsimd/types.h b/include/simsimd/types.h index 0447ccbc..8c2fc5cb 100644 --- a/include/simsimd/types.h +++ b/include/simsimd/types.h @@ -6,18 +6,19 @@ * * Defines: * - Sized aliases for numeric types, like: `simsimd_i32_t` and `simsimd_f64_t`. - * - Macros for compiler/hardware checks, like: `SIMSIMD_TARGET_NEON` + * - Macros for internal compiler/hardware checks, like: `_SIMSIMD_TARGET_ARM`. + * - Macros for feature controls, like: `SIMSIMD_TARGET_NEON` */ #ifndef SIMSIMD_TYPES_H #define SIMSIMD_TYPES_H // Inferring target OS: Windows, MacOS, or Linux #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) -#define SIMSIMD_DEFINED_WINDOWS +#define _SIMSIMD_DEFINED_WINDOWS 1 #elif defined(__APPLE__) && defined(__MACH__) -#define SIMSIMD_DEFINED_APPLE +#define _SIMSIMD_DEFINED_APPLE 1 #elif defined(__linux__) -#define SIMSIMD_DEFINED_LINUX +#define _SIMSIMD_DEFINED_LINUX 1 #endif // Annotation for the public API symbols: @@ -40,73 +41,113 @@ #define SIMSIMD_INTERNAL inline static #endif -// Compiling for Arm: SIMSIMD_TARGET_ARM -#if !defined(SIMSIMD_TARGET_ARM) +// Compiling for Arm: _SIMSIMD_TARGET_ARM +#if !defined(_SIMSIMD_TARGET_ARM) #if defined(__aarch64__) || defined(_M_ARM64) -#define SIMSIMD_TARGET_ARM 1 +#define _SIMSIMD_TARGET_ARM 1 #else -#define SIMSIMD_TARGET_ARM 0 +#define _SIMSIMD_TARGET_ARM 0 #endif // defined(__aarch64__) || defined(_M_ARM64) -#endif // !defined(SIMSIMD_TARGET_ARM) +#endif // !defined(_SIMSIMD_TARGET_ARM) -// Compiling for x86: SIMSIMD_TARGET_X86 -#if !defined(SIMSIMD_TARGET_X86) +// Compiling for x86: _SIMSIMD_TARGET_X86 +#if !defined(_SIMSIMD_TARGET_X86) #if defined(__x86_64__) || defined(_M_X64) -#define SIMSIMD_TARGET_X86 1 +#define _SIMSIMD_TARGET_X86 1 #else -#define SIMSIMD_TARGET_X86 0 +#define _SIMSIMD_TARGET_X86 0 #endif // defined(__x86_64__) || defined(_M_X64) -#endif // !defined(SIMSIMD_TARGET_X86) +#endif // !defined(_SIMSIMD_TARGET_X86) // Compiling for Arm: SIMSIMD_TARGET_NEON -#if !defined(SIMSIMD_TARGET_NEON) || (SIMSIMD_TARGET_NEON && !SIMSIMD_TARGET_ARM) +#if !defined(SIMSIMD_TARGET_NEON) || (SIMSIMD_TARGET_NEON && !_SIMSIMD_TARGET_ARM) #if defined(__ARM_NEON) -#define SIMSIMD_TARGET_NEON SIMSIMD_TARGET_ARM +#define SIMSIMD_TARGET_NEON _SIMSIMD_TARGET_ARM #else #undef SIMSIMD_TARGET_NEON #define SIMSIMD_TARGET_NEON 0 #endif // defined(__ARM_NEON) -#endif // !defined(SIMSIMD_TARGET_NEON) - -#if !defined(SIMSIMD_TARGET_NEON_I8) -#define SIMSIMD_TARGET_NEON_I8 SIMSIMD_TARGET_NEON -#endif // !defined(SIMSIMD_TARGET_NEON_I8) -#if !defined(SIMSIMD_TARGET_NEON_F16) -#define SIMSIMD_TARGET_NEON_F16 SIMSIMD_TARGET_NEON -#endif // !defined(SIMSIMD_TARGET_NEON_F16) -#if !defined(SIMSIMD_TARGET_NEON_BF16) -#define SIMSIMD_TARGET_NEON_BF16 SIMSIMD_TARGET_NEON -#endif // !defined(SIMSIMD_TARGET_NEON_BF16) +#endif // !defined(SIMSIMD_TARGET_NEON) || ... 
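/*
 * Editor's sketch, not from the patch: after this renaming, underscore-prefixed macros like
 * `_SIMSIMD_TARGET_ARM` are internal architecture checks, while `SIMSIMD_TARGET_NEON` stays a
 * user-facing feature control. The guarded cascade above re-derives any request that contradicts
 * the architecture, so passing `-DSIMSIMD_TARGET_NEON=1` to an x86 build still collapses to 0.
 * A minimal probe of the resolved values (assumes the library's `<simsimd/types.h>` include path):
 *
 * @code {.c}
 * #include <stdio.h>
 * #include <simsimd/types.h>
 *
 * int main(void) {
 *     printf("arm=%d x86=%d neon=%d\n", _SIMSIMD_TARGET_ARM, _SIMSIMD_TARGET_X86, SIMSIMD_TARGET_NEON);
 *     return 0;
 * }
 * @endcode
 */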
+ +// Compiling for Arm: SIMSIMD_TARGET_NEON_I8 +#if !defined(SIMSIMD_TARGET_NEON_I8) || (SIMSIMD_TARGET_NEON_I8 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_NEON) +#define SIMSIMD_TARGET_NEON_I8 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_NEON_I8 +#define SIMSIMD_TARGET_NEON_I8 0 +#endif // defined(__ARM_NEON) +#endif // !defined(SIMSIMD_TARGET_NEON_I8) || ... + +// Compiling for Arm: SIMSIMD_TARGET_NEON_F16 +#if !defined(SIMSIMD_TARGET_NEON_F16) || (SIMSIMD_TARGET_NEON_F16 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_NEON) +#define SIMSIMD_TARGET_NEON_F16 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_NEON_F16 +#define SIMSIMD_TARGET_NEON_F16 0 +#endif // defined(__ARM_NEON) +#endif // !defined(SIMSIMD_TARGET_NEON_F16) || ... + +// Compiling for Arm: SIMSIMD_TARGET_NEON_BF16 +#if !defined(SIMSIMD_TARGET_NEON_BF16) || (SIMSIMD_TARGET_NEON_BF16 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_NEON) +#define SIMSIMD_TARGET_NEON_BF16 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_NEON_BF16 +#define SIMSIMD_TARGET_NEON_BF16 0 +#endif // defined(__ARM_NEON) +#endif // !defined(SIMSIMD_TARGET_NEON_BF16) || ... // Compiling for Arm: SIMSIMD_TARGET_SVE -#if !defined(SIMSIMD_TARGET_SVE) || (SIMSIMD_TARGET_SVE && !SIMSIMD_TARGET_ARM) +#if !defined(SIMSIMD_TARGET_SVE) || (SIMSIMD_TARGET_SVE && !_SIMSIMD_TARGET_ARM) #if defined(__ARM_FEATURE_SVE) -#define SIMSIMD_TARGET_SVE SIMSIMD_TARGET_ARM +#define SIMSIMD_TARGET_SVE _SIMSIMD_TARGET_ARM #else #undef SIMSIMD_TARGET_SVE #define SIMSIMD_TARGET_SVE 0 #endif // defined(__ARM_FEATURE_SVE) -#endif // !defined(SIMSIMD_TARGET_SVE) - -#if !defined(SIMSIMD_TARGET_SVE_I8) -#define SIMSIMD_TARGET_SVE_I8 SIMSIMD_TARGET_SVE -#endif // !defined(SIMSIMD_TARGET_SVE_I8) -#if !defined(SIMSIMD_TARGET_SVE_F16) -#define SIMSIMD_TARGET_SVE_F16 SIMSIMD_TARGET_SVE -#endif // !defined(SIMSIMD_TARGET_SVE_F16) -#if !defined(SIMSIMD_TARGET_SVE_BF16) -#define SIMSIMD_TARGET_SVE_BF16 SIMSIMD_TARGET_SVE -#endif // !defined(SIMSIMD_TARGET_SVE_BF16) +#endif // !defined(SIMSIMD_TARGET_SVE) || ... + +// Compiling for Arm: SIMSIMD_TARGET_SVE_I8 +#if !defined(SIMSIMD_TARGET_SVE_I8) || (SIMSIMD_TARGET_SVE_I8 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_FEATURE_SVE) +#define SIMSIMD_TARGET_SVE_I8 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_SVE_I8 +#define SIMSIMD_TARGET_SVE_I8 0 +#endif // defined(__ARM_FEATURE_SVE) +#endif // !defined(SIMSIMD_TARGET_SVE_I8) || ... + +// Compiling for Arm: SIMSIMD_TARGET_SVE_F16 +#if !defined(SIMSIMD_TARGET_SVE_F16) || (SIMSIMD_TARGET_SVE_F16 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_FEATURE_SVE) +#define SIMSIMD_TARGET_SVE_F16 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_SVE_F16 +#define SIMSIMD_TARGET_SVE_F16 0 +#endif // defined(__ARM_FEATURE_SVE) +#endif // !defined(SIMSIMD_TARGET_SVE_F16) || ... + +// Compiling for Arm: SIMSIMD_TARGET_SVE_BF16 +#if !defined(SIMSIMD_TARGET_SVE_BF16) || (SIMSIMD_TARGET_SVE_BF16 && !_SIMSIMD_TARGET_ARM) +#if defined(__ARM_FEATURE_SVE) +#define SIMSIMD_TARGET_SVE_BF16 _SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_SVE_BF16 +#define SIMSIMD_TARGET_SVE_BF16 0 +#endif // defined(__ARM_FEATURE_SVE) +#endif // !defined(SIMSIMD_TARGET_SVE_BF16) || ... 
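/*
 * Editor's note, an illustration rather than part of the patch: each NEON/SVE sub-feature now
 * repeats the full guarded cascade instead of defaulting to the parent macro, so an explicit
 * request that contradicts the target gets re-derived, and a single sub-feature can still be
 * pinned off independently. A hypothetical consumer translation unit that keeps the SVE kernels
 * but opts out of just the bf16 ones:
 *
 * @code {.c}
 * #define SIMSIMD_TARGET_SVE_BF16 0 // veto only the bf16 SVE kernels
 * #include <simsimd/simsimd.h>      // SIMSIMD_TARGET_SVE and the rest resolve as usual
 * @endcode
 */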
// Compiling for Arm: SIMSIMD_TARGET_SVE2 -#if !defined(SIMSIMD_TARGET_SVE2) || (SIMSIMD_TARGET_SVE2 && !SIMSIMD_TARGET_ARM) +#if !defined(SIMSIMD_TARGET_SVE2) || (SIMSIMD_TARGET_SVE2 && !_SIMSIMD_TARGET_ARM) #if defined(__ARM_FEATURE_SVE) -#define SIMSIMD_TARGET_SVE2 SIMSIMD_TARGET_ARM +#define SIMSIMD_TARGET_SVE2 _SIMSIMD_TARGET_ARM #else #undef SIMSIMD_TARGET_SVE2 #define SIMSIMD_TARGET_SVE2 0 #endif // defined(__ARM_FEATURE_SVE) -#endif // !defined(SIMSIMD_TARGET_SVE2) +#endif // !defined(SIMSIMD_TARGET_SVE2) || ... // Compiling for x86: SIMSIMD_TARGET_HASWELL // @@ -115,14 +156,14 @@ // are supported on all CPUs starting with Jaguar 2009. // Starting with Sandy Bridge, Intel adds basic AVX support in their CPUs and in 2013 // extends it with AVX2 in the Haswell generation. Moreover, Haswell adds FMA support. -#if !defined(SIMSIMD_TARGET_HASWELL) || (SIMSIMD_TARGET_HASWELL && !SIMSIMD_TARGET_X86) +#if !defined(SIMSIMD_TARGET_HASWELL) || (SIMSIMD_TARGET_HASWELL && !_SIMSIMD_TARGET_X86) #if defined(__AVX2__) && defined(__FMA__) && defined(__F16C__) #define SIMSIMD_TARGET_HASWELL 1 #else #undef SIMSIMD_TARGET_HASWELL #define SIMSIMD_TARGET_HASWELL 0 #endif // defined(__AVX2__) -#endif // !defined(SIMSIMD_TARGET_HASWELL) +#endif // !defined(SIMSIMD_TARGET_HASWELL) || ... // Compiling for x86: SIMSIMD_TARGET_SKYLAKE, SIMSIMD_TARGET_ICE, SIMSIMD_TARGET_GENOA, // SIMSIMD_TARGET_SAPPHIRE, SIMSIMD_TARGET_TURIN, SIMSIMD_TARGET_SIERRA @@ -131,56 +172,56 @@ // gcc-12 -march=sapphirerapids -dM -E - < /dev/null | egrep "SSE|AVX" | sort // On Arm machines you may want to check for other flags: // gcc-12 -march=native -dM -E - < /dev/null | egrep "NEON|SVE|FP16|FMA" | sort -#if !defined(SIMSIMD_TARGET_SKYLAKE) || (SIMSIMD_TARGET_SKYLAKE && !SIMSIMD_TARGET_X86) -#if defined(__AVX512F__) && defined(__AVX512CD__) && defined(__AVX512VL__) && defined(__AVX512DQ__) && \ +#if !defined(SIMSIMD_TARGET_SKYLAKE) || (SIMSIMD_TARGET_SKYLAKE && !_SIMSIMD_TARGET_X86) +#if defined(__AVX512F__) && defined(__AVX512CD__) && defined(__AVX512VL__) && defined(__AVX512DQ__) && \ defined(__AVX512BW__) #define SIMSIMD_TARGET_SKYLAKE 1 #else #undef SIMSIMD_TARGET_SKYLAKE #define SIMSIMD_TARGET_SKYLAKE 0 #endif -#endif // !defined(SIMSIMD_TARGET_SKYLAKE) -#if !defined(SIMSIMD_TARGET_ICE) || (SIMSIMD_TARGET_ICE && !SIMSIMD_TARGET_X86) -#if defined(__AVX512VNNI__) && defined(__AVX512IFMA__) && defined(__AVX512BITALG__) && defined(__AVX512VBMI2__) && \ +#endif // !defined(SIMSIMD_TARGET_SKYLAKE) || ... +#if !defined(SIMSIMD_TARGET_ICE) || (SIMSIMD_TARGET_ICE && !_SIMSIMD_TARGET_X86) +#if defined(__AVX512VNNI__) && defined(__AVX512IFMA__) && defined(__AVX512BITALG__) && defined(__AVX512VBMI2__) && \ defined(__AVX512VPOPCNTDQ__) #define SIMSIMD_TARGET_ICE 1 #else #undef SIMSIMD_TARGET_ICE #define SIMSIMD_TARGET_ICE 0 #endif -#endif // !defined(SIMSIMD_TARGET_ICE) -#if !defined(SIMSIMD_TARGET_GENOA) || (SIMSIMD_TARGET_GENOA && !SIMSIMD_TARGET_X86) +#endif // !defined(SIMSIMD_TARGET_ICE) || ... +#if !defined(SIMSIMD_TARGET_GENOA) || (SIMSIMD_TARGET_GENOA && !_SIMSIMD_TARGET_X86) #if defined(__AVX512BF16__) #define SIMSIMD_TARGET_GENOA 1 #else #undef SIMSIMD_TARGET_GENOA #define SIMSIMD_TARGET_GENOA 0 #endif -#endif // !defined(SIMSIMD_TARGET_GENOA) -#if !defined(SIMSIMD_TARGET_SAPPHIRE) || (SIMSIMD_TARGET_SAPPHIRE && !SIMSIMD_TARGET_X86) +#endif // !defined(SIMSIMD_TARGET_GENOA) || ... 
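/*
 * Editor's sketch, not part of the patch: the same check the `gcc -dM -E` one-liner above does
 * from the shell can be done from C, confirming which x86 tiers a translation unit was compiled
 * with, e.g. under `-march=icelake-server` versus `-march=haswell`:
 *
 * @code {.c}
 * #include <stdio.h>
 * #include <simsimd/types.h>
 *
 * int main(void) {
 *     printf("haswell=%d skylake=%d ice=%d genoa=%d\n", SIMSIMD_TARGET_HASWELL,
 *            SIMSIMD_TARGET_SKYLAKE, SIMSIMD_TARGET_ICE, SIMSIMD_TARGET_GENOA);
 *     return 0;
 * }
 * @endcode
 */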
+#if !defined(SIMSIMD_TARGET_SAPPHIRE) || (SIMSIMD_TARGET_SAPPHIRE && !_SIMSIMD_TARGET_X86) #if defined(__AVX512FP16__) #define SIMSIMD_TARGET_SAPPHIRE 1 #else #undef SIMSIMD_TARGET_SAPPHIRE #define SIMSIMD_TARGET_SAPPHIRE 0 #endif -#endif // !defined(SIMSIMD_TARGET_SAPPHIRE) -#if !defined(SIMSIMD_TARGET_TURIN) || (SIMSIMD_TARGET_TURIN && !SIMSIMD_TARGET_X86) +#endif // !defined(SIMSIMD_TARGET_SAPPHIRE) || ... +#if !defined(SIMSIMD_TARGET_TURIN) || (SIMSIMD_TARGET_TURIN && !_SIMSIMD_TARGET_X86) #if defined(__AVX512VP2INTERSECT__) #define SIMSIMD_TARGET_TURIN 1 #else #undef SIMSIMD_TARGET_TURIN #define SIMSIMD_TARGET_TURIN 0 #endif -#endif // !defined(SIMSIMD_TARGET_TURIN) -#if !defined(SIMSIMD_TARGET_SIERRA) || (SIMSIMD_TARGET_SIERRA && !SIMSIMD_TARGET_X86) +#endif // !defined(SIMSIMD_TARGET_TURIN) || ... +#if !defined(SIMSIMD_TARGET_SIERRA) || (SIMSIMD_TARGET_SIERRA && !_SIMSIMD_TARGET_X86) #if defined(__AVX2_VNNI__) #define SIMSIMD_TARGET_SIERRA 1 #else #undef SIMSIMD_TARGET_SIERRA #define SIMSIMD_TARGET_SIERRA 0 #endif -#endif // !defined(SIMSIMD_TARGET_SIERRA) +#endif // !defined(SIMSIMD_TARGET_SIERRA) || ... #ifdef _MSC_VER #include @@ -194,7 +235,7 @@ #include #endif -#if SIMSIMD_TARGET_HASWELL || SIMSIMD_TARGET_SKYLAKE || SIMSIMD_TARGET_ICE || SIMSIMD_TARGET_GENOA || \ +#if SIMSIMD_TARGET_HASWELL || SIMSIMD_TARGET_SKYLAKE || SIMSIMD_TARGET_ICE || SIMSIMD_TARGET_GENOA || \ SIMSIMD_TARGET_SAPPHIRE || SIMSIMD_TARGET_TURIN #include #endif @@ -253,12 +294,12 @@ typedef simsimd_f64_t simsimd_distance_t; * - Default: `unsigned short`. */ #if !defined(SIMSIMD_NATIVE_F16) || SIMSIMD_NATIVE_F16 -#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ +#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ (defined(__ARM_FP16_FORMAT_IEEE)) #undef SIMSIMD_NATIVE_F16 #define SIMSIMD_NATIVE_F16 1 typedef __fp16 simsimd_f16_t; -#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ +#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ (defined(__AVX512FP16__))) typedef _Float16 simsimd_f16_t; #undef SIMSIMD_NATIVE_F16 @@ -301,12 +342,12 @@ typedef unsigned short simsimd_f16_t; * https://forums.developer.apple.com/forums/thread/726201 * https://www.phoronix.com/news/GCC-LLVM-bf16-BFloat16-Type */ -#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ +#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ (defined(__ARM_BF16_FORMAT_ALTERNATIVE)) #undef SIMSIMD_NATIVE_BF16 #define SIMSIMD_NATIVE_BF16 1 typedef __bf16 simsimd_bf16_t; -#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ +#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ (defined(__AVX512BF16__))) typedef __bfloat16 simsimd_bf16_t; #undef SIMSIMD_NATIVE_BF16 @@ -333,7 +374,7 @@ typedef unsigned short simsimd_bf16_t; * Some of those are defined as aliases, so we use `#define` preprocessor * directives instead of `typedef` to avoid errors. 
*/ -#if SIMSIMD_TARGET_ARM +#if _SIMSIMD_TARGET_ARM #if defined(_MSC_VER) #define simsimd_f16_for_arm_simd_t simsimd_f16_t #define simsimd_bf16_for_arm_simd_t simsimd_bf16_t @@ -394,6 +435,19 @@ SIMSIMD_STATIC_ASSERT(sizeof(simsimd_bf16_t) == 2, simsimd_bf16_t_must_be_2_byte #endif #endif +#if !defined(SIMSIMD_F32_TO_I8) +#define SIMSIMD_F32_TO_I8(x, y) *(y) = (simsimd_i8_t)fminf(fmaxf(roundf(x), -128), 127) +#endif +#if !defined(SIMSIMD_F32_TO_U8) +#define SIMSIMD_F32_TO_U8(x, y) *(y) = (simsimd_u8_t)fminf(fmaxf(roundf(x), 0), 255) +#endif +#if !defined(SIMSIMD_F64_TO_I8) +#define SIMSIMD_F64_TO_I8(x, y) *(y) = (simsimd_i8_t)fmin(fmax(round(x), -128), 127) +#endif +#if !defined(SIMSIMD_F64_TO_U8) +#define SIMSIMD_F64_TO_U8(x, y) *(y) = (simsimd_u8_t)fmin(fmax(round(x), 0), 255) +#endif + /** @brief Convenience type for half-precision floating-point type conversions. */ typedef union { unsigned i; @@ -443,8 +497,8 @@ SIMSIMD_PUBLIC simsimd_f32_t simsimd_approximate_log(simsimd_f32_t number) { * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233 * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834 */ -SIMSIMD_PUBLIC simsimd_f32_t simsimd_f16_to_f32(simsimd_f16_t const* x_ptr) { - unsigned short x = *(unsigned short const*)x_ptr; +SIMSIMD_PUBLIC simsimd_f32_t simsimd_f16_to_f32(simsimd_f16_t const *x_ptr) { + unsigned short x = *(unsigned short const *)x_ptr; unsigned int exponent = (x & 0x7C00) >> 10; unsigned int mantissa = (x & 0x03FF) << 13; simsimd_f32i32_t mantissa_conv; @@ -465,7 +519,7 @@ SIMSIMD_PUBLIC simsimd_f32_t simsimd_f16_to_f32(simsimd_f16_t const* x_ptr) { * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233 * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834 */ -SIMSIMD_PUBLIC void simsimd_f32_to_f16(simsimd_f32_t x, simsimd_f16_t* result_ptr) { +SIMSIMD_PUBLIC void simsimd_f32_to_f16(simsimd_f32_t x, simsimd_f16_t *result_ptr) { simsimd_f32i32_t conv; conv.f = x; unsigned int b = conv.i + 0x00001000; @@ -474,7 +528,7 @@ SIMSIMD_PUBLIC void simsimd_f32_to_f16(simsimd_f32_t x, simsimd_f16_t* result_pt unsigned short result = ((b & 0x80000000) >> 16) | (e > 112) * ((((e - 112) << 10) & 0x7C00) | (m >> 13)) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | ((e > 143) * 0x7FFF); - *(unsigned short*)result_ptr = result; + *(unsigned short *)result_ptr = result; } /** @@ -484,8 +538,8 @@ SIMSIMD_PUBLIC void simsimd_f32_to_f16(simsimd_f32_t x, simsimd_f16_t* result_pt * https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307 * https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus */ -SIMSIMD_PUBLIC simsimd_f32_t simsimd_bf16_to_f32(simsimd_bf16_t const* x_ptr) { - unsigned short x = *(unsigned short const*)x_ptr; +SIMSIMD_PUBLIC simsimd_f32_t simsimd_bf16_to_f32(simsimd_bf16_t const *x_ptr) { + unsigned short x = *(unsigned short const *)x_ptr; simsimd_f32i32_t conv; conv.i = x << 16; // Zero extends the mantissa return conv.f; @@ -497,14 +551,14 @@ SIMSIMD_PUBLIC simsimd_f32_t simsimd_bf16_to_f32(simsimd_bf16_t const* x_ptr) { * https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307 * https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus */ -SIMSIMD_PUBLIC void simsimd_f32_to_bf16(simsimd_f32_t x, simsimd_bf16_t* result_ptr) { 
+SIMSIMD_PUBLIC void simsimd_f32_to_bf16(simsimd_f32_t x, simsimd_bf16_t *result_ptr) { simsimd_f32i32_t conv; conv.f = x; conv.i += 0x8000; // Rounding is optional conv.i >>= 16; // The top 16 bits will be zeroed out anyways // conv.i &= 0xFFFF; - *(unsigned short*)result_ptr = (unsigned short)conv.i; + *(unsigned short *)result_ptr = (unsigned short)conv.i; } SIMSIMD_PUBLIC simsimd_u32_t simsimd_u32_rol(simsimd_u32_t x, int n) { return (x << n) | (x >> (32 - n)); } diff --git a/javascript/lib.c b/javascript/lib.c index eacf8abd..1a461957 100644 --- a/javascript/lib.c +++ b/javascript/lib.c @@ -67,8 +67,7 @@ napi_value runAPI(napi_env env, napi_callback_info info, simsimd_metric_kind_t m // Convert the result to a JavaScript number napi_value js_result; status = napi_create_double(env, result, &js_result); - if (status != napi_ok) - return NULL; + if (status != napi_ok) return NULL; return js_result; } diff --git a/package.json b/package.json index e3f44bab..c114d028 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "simsimd", "version": "5.8.0", - "description": "Fastest SIMD-Accelerated Vector Similarity Functions for x86 and Arm", + "description": "Portable mixed-precision BLAS-like vector math library for x86 and ARM", "homepage": "https://github.com/ashvardanian/simsimd", "author": "Ash Vardanian", "license": "Apache 2.0", diff --git a/python/lib.c b/python/lib.c index 9e2e87d7..debe289e 100644 --- a/python/lib.c +++ b/python/lib.c @@ -4,6 +4,64 @@ * @author Ash Vardanian * @date January 1, 2023 * @copyright Copyright (c) 2023 + * + * @section Latency, Quality, and Arguments Parsing + * + * The complexity of implementing high-quality CPython bindings is often underestimated. + * You can't use high-level wrappers like PyBind11 and NanoBind, and you shouldn't use + * SWIG-like messy toolchains. Most of them use expensive dynamic data-structures to map + * your callbacks to object/module properties, not taking advantage of the CPython API. + * They are prohibitively slow for low-latency operations like checking the length of a + * container, handling vectors, or strings. + * + * Once you are down to the CPython API, there is a lot of boilerplate code to write and + * it's understandable that most people lazily use the `PyArg_ParseTupleAndKeywords` and + * `PyArg_ParseTuple` functions. Those, however, need to dynamically parse format specifier + * strings at runtime, which @b can't be fast by design! Moreover, they are not suitable + * for Python's "Fast Calling Convention". In a typical scenario, a function is defined + * with `METH_VARARGS | METH_KEYWORDS` and has a signature like: + * + * @code {.c} + * static PyObject* cdist( + * PyObject * self, + * PyObject * positional_args_tuple, + * PyObject * named_args_dict) { + * PyObject * a_obj, b_obj, metric_obj, out_obj, dtype_obj, out_dtype_obj, threads_obj; + * static char* names[] = {"a", "b", "metric", "threads", "dtype", "out_dtype", NULL}; + * if (!PyArg_ParseTupleAndKeywords( + * positional_args_tuple, named_args_dict, "OO|s$Kss", names, + * &a_obj, &b_obj, &metric_str, &threads, &dtype_str, &out_dtype_str)) + * return NULL; + * ... + * @endcode + * + * This `cdist` example takes 2 positional, 1 positional-or-named, and 3 named-only arguments. + * The alternative, using `METH_FASTCALL`, is a function signature like: + * + * @code {.c} + * static PyObject* cdist( + * PyObject * self, + * PyObject * const * args_c_array, //! C array of `args_count` pointers + * Py_ssize_t const positional_args_count, //!
The `args_c_array` may be larger than this + * PyObject * args_names_tuple) { //! May be smaller than `args_count` + * Py_ssize_t args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0; + * Py_ssize_t args_count = positional_args_count + args_names_count; + * ... + * @endcode + * + * The positional elements are easy to access in that C array, but parsing the named arguments is tricky. + * There may be a case, when the call is ill-formed and more positional arguments are provided than needed. + * + * @code {.py} + * cdist(a, b, "cos", "dos"): //! positional_args_count == 4, args_names_count == 0 + * cdist(a, b, "cos", metric="dos"): //! positional_args_count == 3, args_names_count == 1 + * cdist(a, b, metric="cos", metric="dos"): //! positional_args_count == 2, args_names_count == 2 + * @endcode + * + * If the same argument is provided twice, a @b `TypeError` is raised. + * If the argument is not found, a @b `KeyError` is raised. + * + * https://ashvardanian.com/posts/discount-on-keyword-arguments-in-python/ */ #include @@ -19,7 +77,7 @@ #include typedef struct TensorArgument { - char* start; + char *start; size_t dimensions; size_t count; size_t stride; @@ -36,8 +94,8 @@ typedef struct DistancesTensor { simsimd_distance_t start[]; // Variable length data aligned to 64-bit scalars } DistancesTensor; -static int DistancesTensor_getbuffer(PyObject* export_from, Py_buffer* view, int flags); -static void DistancesTensor_releasebuffer(PyObject* export_from, Py_buffer* view); +static int DistancesTensor_getbuffer(PyObject *export_from, Py_buffer *view, int flags); +static void DistancesTensor_releasebuffer(PyObject *export_from, Py_buffer *view); static PyBufferProcs DistancesTensor_as_buffer = { .bf_getbuffer = DistancesTensor_getbuffer, @@ -59,7 +117,7 @@ simsimd_capability_t static_capabilities = simsimd_cap_serial_k; /// @brief Helper method to check for string equality. /// @return 1 if the strings are equal, 0 otherwise. -int same_string(char const* a, char const* b) { return strcmp(a, b) == 0; } +int same_string(char const *a, char const *b) { return strcmp(a, b) == 0; } /// @brief Helper method to check if a logical datatype is complex and should be represented as two scalars. /// @return 1 if the datatype is complex, 0 otherwise. @@ -71,7 +129,7 @@ int is_complex(simsimd_datatype_t datatype) { /// @brief Converts a numpy datatype string to a logical datatype, normalizing the format. /// @return `simsimd_datatype_unknown_k` if the datatype is not supported, otherwise the logical datatype. 
/// @see https://docs.python.org/3/library/struct.html#format-characters -simsimd_datatype_t numpy_string_to_datatype(char const* name) { +simsimd_datatype_t numpy_string_to_datatype(char const *name) { // Floating-point numbers: if (same_string(name, "f") || same_string(name, "format); @@ -430,41 +491,41 @@ int parse_tensor(PyObject* tensor, Py_buffer* buffer, TensorArgument* parsed) { if (buffer->strides[0] > buffer->itemsize) { PyErr_SetString(PyExc_ValueError, "Input vectors must be contiguous, check with `X.__array_interface__`"); PyBuffer_Release(buffer); - return -1; + return 0; } parsed->dimensions = buffer->shape[0]; parsed->count = 1; parsed->stride = 0; - } else if (buffer->ndim == 2) { + } + else if (buffer->ndim == 2) { if (buffer->strides[1] > buffer->itemsize) { PyErr_SetString(PyExc_ValueError, "Input vectors must be contiguous, check with `X.__array_interface__`"); PyBuffer_Release(buffer); - return -1; + return 0; } parsed->dimensions = buffer->shape[1]; parsed->count = buffer->shape[0]; parsed->stride = buffer->strides[0]; - } else { + } + else { PyErr_SetString(PyExc_ValueError, "Input tensors must be 1D or 2D"); PyBuffer_Release(buffer); - return -1; + return 0; } // We handle complex numbers differently - if (is_complex(parsed->datatype)) { - parsed->dimensions *= 2; - } + if (is_complex(parsed->datatype)) { parsed->dimensions *= 2; } - return 0; + return 1; } -static int DistancesTensor_getbuffer(PyObject* export_from, Py_buffer* view, int flags) { - DistancesTensor* tensor = (DistancesTensor*)export_from; +static int DistancesTensor_getbuffer(PyObject *export_from, Py_buffer *view, int flags) { + DistancesTensor *tensor = (DistancesTensor *)export_from; size_t const total_items = tensor->shape[0] * tensor->shape[1]; size_t const item_size = bytes_per_datatype(tensor->datatype); view->buf = &tensor->start[0]; - view->obj = (PyObject*)tensor; + view->obj = (PyObject *)tensor; view->len = item_size * total_items; view->readonly = 0; view->itemsize = (Py_ssize_t)item_size; @@ -479,459 +540,638 @@ static int DistancesTensor_getbuffer(PyObject* export_from, Py_buffer* view, int return 0; } -static void DistancesTensor_releasebuffer(PyObject* export_from, Py_buffer* view) { +static void DistancesTensor_releasebuffer(PyObject *export_from, Py_buffer *view) { // This function MUST NOT decrement view->obj, since that is done automatically in PyBuffer_Release(). 
// https://docs.python.org/3/c-api/typeobj.html#c.PyBufferProcs.bf_releasebuffer } -static PyObject* implement_dense_metric(simsimd_metric_kind_t metric_kind, PyObject* const* args, Py_ssize_t nargs) { - // Function now accepts up to 3 arguments, the third being optional - if (nargs < 2 || nargs > 3) { - PyErr_SetString(PyExc_TypeError, "Function expects 2 or 3 arguments"); +static PyObject *implement_dense_metric( // + simsimd_metric_kind_t metric_kind, // + PyObject *const *args, Py_ssize_t const positional_args_count, PyObject *args_names_tuple) { + + PyObject *return_obj = NULL; + + // This function accepts up to 5 arguments: + PyObject *a_obj = NULL; // Required object, positional-only + PyObject *b_obj = NULL; // Required object, positional-only + PyObject *dtype_obj = NULL; // Optional object, "dtype" keyword or positional + PyObject *out_obj = NULL; // Optional object, "out" keyword-only + PyObject *out_dtype_obj = NULL; // Optional object, "out_dtype" keyword-only + + // Once parsed, the arguments will be stored in these variables: + char const *dtype_str = NULL, *out_dtype_str = NULL; + simsimd_datatype_t dtype = simsimd_datatype_unknown_k, out_dtype = simsimd_datatype_unknown_k; + Py_buffer a_buffer, b_buffer, out_buffer; + TensorArgument a_parsed, b_parsed, out_parsed; + memset(&a_buffer, 0, sizeof(Py_buffer)); + memset(&b_buffer, 0, sizeof(Py_buffer)); + memset(&out_buffer, 0, sizeof(Py_buffer)); + + // Parse the arguments + Py_ssize_t const args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0; + Py_ssize_t const args_count = positional_args_count + args_names_count; + if (args_count < 2 || args_count > 5) { + PyErr_Format(PyExc_TypeError, "Function expects 2-5 arguments, got %zd", args_count); + return NULL; + } + if (positional_args_count > 3) { + PyErr_Format(PyExc_TypeError, "Only first 3 arguments can be positional, received %zd", positional_args_count); return NULL; } - PyObject* output = NULL; - PyObject* input_tensor_a = args[0]; - PyObject* input_tensor_b = args[1]; - PyObject* input_datatype_desc = nargs == 3 ? 
args[2] : NULL; + // Positional-only arguments (first and second matrix) + a_obj = args[0]; + b_obj = args[1]; + + // Positional or keyword arguments (dtype) + if (positional_args_count == 3) dtype_obj = args[2]; + + // The rest of the arguments must be checked in the keyword dictionary: + for (Py_ssize_t args_names_tuple_progress = 0, args_progress = positional_args_count; + args_names_tuple_progress < args_names_count; ++args_progress, ++args_names_tuple_progress) { + PyObject *const key = PyTuple_GetItem(args_names_tuple, args_names_tuple_progress); + PyObject *const value = args[args_progress]; + if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0 && !dtype_obj) { dtype_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "out") == 0 && !out_obj) { out_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "out_dtype") == 0 && !out_dtype_obj) { out_dtype_obj = value; } + else { + PyErr_Format(PyExc_TypeError, "Got unexpected keyword argument: %S", key); + return NULL; + } + } + + // Convert `dtype_obj` to `dtype_str` and to `dtype` + if (dtype_obj) { + dtype_str = PyUnicode_AsUTF8(dtype_obj); + if (!dtype_str && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "Expected 'dtype' to be a string"); + return NULL; + } + dtype = python_string_to_datatype(dtype_str); + if (dtype == simsimd_datatype_unknown_k) { + PyErr_SetString(PyExc_ValueError, "Unsupported 'dtype'"); + return NULL; + } + } - Py_buffer buffer_a, buffer_b; - TensorArgument parsed_a, parsed_b; - if (parse_tensor(input_tensor_a, &buffer_a, &parsed_a) != 0 || - parse_tensor(input_tensor_b, &buffer_b, &parsed_b) != 0) { - return NULL; // Error already set by parse_tensor + // Convert `out_dtype_obj` to `out_dtype_str` and to `out_dtype` + if (out_dtype_obj) { + out_dtype_str = PyUnicode_AsUTF8(out_dtype_obj); + if (!out_dtype_str && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "Expected 'out_dtype' to be a string"); + return NULL; + } + out_dtype = python_string_to_datatype(out_dtype_str); + if (out_dtype == simsimd_datatype_unknown_k) { + PyErr_SetString(PyExc_ValueError, "Unsupported 'out_dtype'"); + return NULL; + } } + // Convert `a_obj` to `a_buffer` and to `a_parsed`. Same for `b_obj` and `out_obj`. 
+ if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed)) return NULL; + if (out_obj && !parse_tensor(out_obj, &out_buffer, &out_parsed)) return NULL; + // Check dimensions - if (parsed_a.dimensions != parsed_b.dimensions) { + if (a_parsed.dimensions != b_parsed.dimensions) { PyErr_SetString(PyExc_ValueError, "Vector dimensions don't match"); goto cleanup; } - if (parsed_a.count == 0 || parsed_b.count == 0) { + if (a_parsed.count == 0 || b_parsed.count == 0) { PyErr_SetString(PyExc_ValueError, "Collections can't be empty"); goto cleanup; } - if (parsed_a.count > 1 && parsed_b.count > 1 && parsed_a.count != parsed_b.count) { + if (a_parsed.count > 1 && b_parsed.count > 1 && a_parsed.count != b_parsed.count) { PyErr_SetString(PyExc_ValueError, "Collections must have the same number of elements or just one element"); goto cleanup; } // Check data types - if (parsed_a.datatype != parsed_b.datatype && parsed_a.datatype != simsimd_datatype_unknown_k && - parsed_b.datatype != simsimd_datatype_unknown_k) { + if (a_parsed.datatype != b_parsed.datatype || // + a_parsed.datatype == simsimd_datatype_unknown_k || b_parsed.datatype == simsimd_datatype_unknown_k) { PyErr_SetString(PyExc_TypeError, "Input tensors must have matching datatypes, check with `X.__array_interface__`"); goto cleanup; } + if (dtype == simsimd_datatype_unknown_k) dtype = a_parsed.datatype; + + // Inference order for the output type: + // 1. `out_dtype` named argument, if defined + // 2. `out.dtype` attribute, if `out` is passed + // 3. double precision float (or its complex variant) + if (out_dtype == simsimd_datatype_unknown_k) { + if (out_obj) { out_dtype = out_parsed.datatype; } + else { out_dtype = is_complex(dtype) ? simsimd_datatype_f64c_k : simsimd_datatype_f64_k; } + } - // Process the third argument, `input_datatype_desc`, if provided - simsimd_datatype_t input_datatype = parsed_a.datatype; - char const* input_datatype_str = ""; - if (input_datatype_desc != NULL) { - // Ensure it is a string (or convert it to one if possible) - if (!PyUnicode_Check(input_datatype_desc)) { - PyErr_SetString(PyExc_TypeError, "third argument must be a string describing the value type"); - goto cleanup; - } - // Convert Python string to C string - input_datatype_str = PyUnicode_AsUTF8(input_datatype_desc); - if (!input_datatype_str) { - PyErr_SetString(PyExc_ValueError, "Could not convert value type description to string"); + // Make sure the return datatype is complex if the input datatype is complex, and the same for real numbers + if (out_dtype != simsimd_datatype_unknown_k) { + if (is_complex(dtype) != is_complex(out_dtype)) { + PyErr_SetString( + PyExc_ValueError, + "If the input datatype is complex, the return datatype must be complex, and same for real."); goto cleanup; } - input_datatype = python_string_to_datatype(input_datatype_str); - if (input_datatype == simsimd_datatype_unknown_k) { - PyErr_Format(PyExc_ValueError, "Unsupported datatype '%s'", input_datatype_str); + } + + // Check if the downcasting to provided datatype is supported + { + char returned_buffer_example[8]; + if (!cast_distance(0, out_dtype, &returned_buffer_example, 0)) { + PyErr_SetString(PyExc_ValueError, "Exporting to the provided datatype is not supported"); goto cleanup; } } + // Look up the metric and the capability simsimd_metric_punned_t metric = NULL; simsimd_capability_t capability = simsimd_cap_serial_k; - simsimd_find_metric_punned(metric_kind, input_datatype, static_capabilities, simsimd_cap_any_k, &metric, - 
&capability); + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, &metric, &capability); if (!metric) { - PyErr_Format(PyExc_LookupError, - "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s' and '%s'/'%s') and " - "`dtype` override ('%s'/'%s')", - metric_kind, // - buffer_a.format, datatype_to_python_string(parsed_a.datatype), // - buffer_b.format, datatype_to_python_string(parsed_b.datatype), // - input_datatype_str, datatype_to_python_string(input_datatype)); + PyErr_Format( // + PyExc_LookupError, + "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s' and '%s'/'%s') and " + "`dtype` override ('%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + b_buffer.format ? b_buffer.format : "nil", datatype_to_python_string(b_parsed.datatype), // + dtype_str ? dtype_str : "nil", datatype_to_python_string(dtype)); goto cleanup; } // If the distance is computed between two vectors, rather than matrices, return a scalar - int datatype_is_complex = is_complex(input_datatype); - simsimd_datatype_t return_datatype = datatype_is_complex ? simsimd_datatype_f64c_k : simsimd_datatype_f64_k; - if (parsed_a.rank == 1 && parsed_b.rank == 1) { + int const dtype_is_complex = is_complex(dtype); + if (a_parsed.rank == 1 && b_parsed.rank == 1) { // For complex numbers we are going to use `PyComplex_FromDoubles`. - if (datatype_is_complex) { + if (dtype_is_complex) { simsimd_distance_t distances[2]; - metric(parsed_a.start, parsed_b.start, parsed_a.dimensions, distances); - output = PyComplex_FromDoubles(distances[0], distances[1]); - } else { + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); + return_obj = PyComplex_FromDoubles(distances[0], distances[1]); + } + else { simsimd_distance_t distance; - metric(parsed_a.start, parsed_b.start, parsed_a.dimensions, &distance); - output = PyFloat_FromDouble(distance); + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, &distance); + return_obj = PyFloat_FromDouble(distance); } - } else { - - // In some batch requests we may be computing the distance from multiple vectors to one, - // so the stride must be set to zero avoid illegal memory access - if (parsed_a.count == 1) - parsed_a.stride = 0; - if (parsed_b.count == 1) - parsed_b.stride = 0; - - // We take the maximum of the two counts, because if only one entry is present in one of the arrays, - // all distances will be computed against that single entry. - size_t const count_pairs = parsed_a.count > parsed_b.count ? parsed_a.count : parsed_b.count; - size_t const components_per_pair = datatype_is_complex ? 2 : 1; - size_t const count_components = count_pairs * components_per_pair; - DistancesTensor* distances_obj = PyObject_NewVar(DistancesTensor, &DistancesTensorType, - count_components * bytes_per_datatype(return_datatype)); + goto cleanup; + } + + // In some batch requests we may be computing the distance from multiple vectors to one, + // so the stride must be set to zero avoid illegal memory access + if (a_parsed.count == 1) a_parsed.stride = 0; + if (b_parsed.count == 1) b_parsed.stride = 0; + + // We take the maximum of the two counts, because if only one entry is present in one of the arrays, + // all distances will be computed against that single entry. + size_t const count_pairs = a_parsed.count > b_parsed.count ? a_parsed.count : b_parsed.count; + size_t const components_per_pair = dtype_is_complex ? 
2 : 1; + size_t const count_components = count_pairs * components_per_pair; + char *distances_start = NULL; + size_t distances_stride_bytes = 0; + + // Allocate the output matrix if it wasn't provided + if (!out_obj) { + DistancesTensor *distances_obj = + PyObject_NewVar(DistancesTensor, &DistancesTensorType, count_components * bytes_per_datatype(out_dtype)); if (!distances_obj) { PyErr_NoMemory(); goto cleanup; } // Initialize the object - distances_obj->datatype = return_datatype; + distances_obj->datatype = out_dtype; distances_obj->dimensions = 1; distances_obj->shape[0] = count_pairs; distances_obj->shape[1] = 1; - distances_obj->strides[0] = bytes_per_datatype(return_datatype); + distances_obj->strides[0] = bytes_per_datatype(out_dtype); distances_obj->strides[1] = 0; - output = (PyObject*)distances_obj; - - // Compute the distances - simsimd_distance_t* distances = (simsimd_distance_t*)&distances_obj->start[0]; - for (size_t i = 0; i < count_pairs; ++i) { - simsimd_distance_t result[2]; - metric( // - parsed_a.start + i * parsed_a.stride, // - parsed_b.start + i * parsed_b.stride, // - parsed_a.dimensions, // - (simsimd_distance_t*)&result); - - // Export out: - if (!cast_distance(result[0], return_datatype, distances, i * components_per_pair)) { - PyObject_Del(distances_obj); - output = NULL; - PyErr_SetString(PyExc_ValueError, "Unsupported datatype"); - goto cleanup; - } - if (datatype_is_complex) - cast_distance(result[1], return_datatype, distances, i * components_per_pair + 1); + return_obj = (PyObject *)distances_obj; + distances_start = (char *)&distances_obj->start[0]; + distances_stride_bytes = distances_obj->strides[0]; + } + else { + if (bytes_per_datatype(out_parsed.datatype) != bytes_per_datatype(out_dtype)) { + PyErr_Format( // + PyExc_LookupError, + "Output tensor scalar type must be compatible with the output type ('%s' and '%s'/'%s')", + datatype_to_python_string(out_dtype), out_buffer.format ? out_buffer.format : "nil", + datatype_to_python_string(out_parsed.datatype)); + goto cleanup; } + distances_start = (char *)&out_parsed.start[0]; + distances_stride_bytes = out_buffer.strides[0]; + //? Logic suggests to return `None` in in-place mode... + //? SciPy decided differently. 
+ return_obj = Py_None; + } + + // Compute the distances + for (size_t i = 0; i < count_pairs; ++i) { + simsimd_distance_t result[2]; + metric( // + a_parsed.start + i * a_parsed.stride, // + b_parsed.start + i * b_parsed.stride, // + a_parsed.dimensions, // + (simsimd_distance_t *)&result); + + // Export out: + cast_distance(result[0], out_dtype, distances_start + i * distances_stride_bytes, 0); + if (dtype_is_complex) cast_distance(result[1], out_dtype, distances_start + i * distances_stride_bytes, 1); } cleanup: - PyBuffer_Release(&buffer_a); - PyBuffer_Release(&buffer_b); - return output; + PyBuffer_Release(&a_buffer); + PyBuffer_Release(&b_buffer); + PyBuffer_Release(&out_buffer); + return return_obj; } -static PyObject* implement_curved_metric(simsimd_metric_kind_t metric_kind, PyObject* const* args, Py_ssize_t nargs) { - // Function now accepts up to 4 arguments, the fourth being optional - if (nargs < 3 || nargs > 4) { - PyErr_SetString(PyExc_TypeError, "Function expects 4 or 5 arguments"); +static PyObject *implement_curved_metric( // + simsimd_metric_kind_t metric_kind, // + PyObject *const *args, Py_ssize_t const positional_args_count, PyObject *args_names_tuple) { + + PyObject *return_obj = NULL; + + // This function accepts up to 6 arguments: + PyObject *a_obj = NULL; // Required object, positional-only + PyObject *b_obj = NULL; // Required object, positional-only + PyObject *c_obj = NULL; // Required object, positional-only + PyObject *dtype_obj = NULL; // Optional object, "dtype" keyword or positional + + // Once parsed, the arguments will be stored in these variables: + char const *dtype_str = NULL; + simsimd_datatype_t dtype = simsimd_datatype_unknown_k; + Py_buffer a_buffer, b_buffer, c_buffer; + TensorArgument a_parsed, b_parsed, c_parsed; + memset(&a_buffer, 0, sizeof(Py_buffer)); + memset(&b_buffer, 0, sizeof(Py_buffer)); + memset(&c_buffer, 0, sizeof(Py_buffer)); + + // Parse the arguments + Py_ssize_t const args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0; + Py_ssize_t const args_count = positional_args_count + args_names_count; + if (args_count < 3 || args_count > 6) { + PyErr_Format(PyExc_TypeError, "Function expects 3-6 arguments, got %zd", args_count); + return NULL; + } + if (positional_args_count > 4) { + PyErr_Format(PyExc_TypeError, "Only first 4 arguments can be positional, received %zd", positional_args_count); return NULL; } - PyObject* output = NULL; - PyObject* input_tensor_a = args[0]; - PyObject* input_tensor_b = args[1]; - PyObject* input_tensor_c = args[2]; - PyObject* input_datatype_desc = nargs == 4 ?
args[3] : NULL; + // Positional-only arguments (two input vectors and the metric tensor) + a_obj = args[0]; + b_obj = args[1]; + c_obj = args[2]; + + // Positional or keyword arguments (dtype) + if (positional_args_count == 4) dtype_obj = args[3]; + + // The rest of the arguments must be checked in the keyword dictionary: + for (Py_ssize_t args_names_tuple_progress = 0, args_progress = positional_args_count; + args_names_tuple_progress < args_names_count; ++args_progress, ++args_names_tuple_progress) { + PyObject *const key = PyTuple_GetItem(args_names_tuple, args_names_tuple_progress); + PyObject *const value = args[args_progress]; + if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0 && !dtype_obj) { dtype_obj = value; } + else { + PyErr_Format(PyExc_TypeError, "Got unexpected keyword argument: %S", key); + return NULL; + } + } - Py_buffer buffer_a, buffer_b, buffer_c; - TensorArgument parsed_a, parsed_b, parsed_c; - if (parse_tensor(input_tensor_a, &buffer_a, &parsed_a) != 0 || - parse_tensor(input_tensor_b, &buffer_b, &parsed_b) != 0 || - parse_tensor(input_tensor_c, &buffer_c, &parsed_c) != 0) { - return NULL; // Error already set by parse_tensor + // Convert `dtype_obj` to `dtype_str` and to `dtype` + if (dtype_obj) { + dtype_str = PyUnicode_AsUTF8(dtype_obj); + if (!dtype_str && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "Expected 'dtype' to be a string"); + return NULL; + } + dtype = python_string_to_datatype(dtype_str); + if (dtype == simsimd_datatype_unknown_k) { + PyErr_SetString(PyExc_ValueError, "Unsupported 'dtype'"); + return NULL; + } } + // Convert `a_obj` to `a_buffer` and to `a_parsed`. Same for `b_obj` and `c_obj`. + if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed) || + !parse_tensor(c_obj, &c_buffer, &c_parsed)) + return NULL; + // Check dimensions - if (parsed_a.rank != 1 || parsed_b.rank != 1) { + if (a_parsed.rank != 1 || b_parsed.rank != 1) { PyErr_SetString(PyExc_ValueError, "First and second argument must be vectors"); goto cleanup; } - if (parsed_c.rank != 2) { + if (c_parsed.rank != 2) { PyErr_SetString(PyExc_ValueError, "Third argument must be a matrix (rank-2 tensor)"); goto cleanup; } - if (parsed_a.count == 0 || parsed_b.count == 0) { + if (a_parsed.count == 0 || b_parsed.count == 0) { PyErr_SetString(PyExc_ValueError, "Collections can't be empty"); goto cleanup; } - if (parsed_a.count > 1 && parsed_b.count > 1 && parsed_a.count != parsed_b.count) { + if (a_parsed.count > 1 && b_parsed.count > 1 && a_parsed.count != b_parsed.count) { PyErr_SetString(PyExc_ValueError, "Collections must have the same number of elements or just one element"); goto cleanup; } // Check data types - if (parsed_a.datatype != parsed_b.datatype && parsed_a.datatype != simsimd_datatype_unknown_k && - parsed_b.datatype != simsimd_datatype_unknown_k) { + if (a_parsed.datatype != b_parsed.datatype || a_parsed.datatype != c_parsed.datatype || + a_parsed.datatype == simsimd_datatype_unknown_k || b_parsed.datatype == simsimd_datatype_unknown_k || + c_parsed.datatype == simsimd_datatype_unknown_k) { PyErr_SetString(PyExc_TypeError, "Input tensors must have matching datatypes, check with `X.__array_interface__`"); goto cleanup; } + if (dtype == simsimd_datatype_unknown_k) dtype = a_parsed.datatype; - // Process the third argument, `input_datatype_desc`, if provided - simsimd_datatype_t input_datatype = parsed_a.datatype; - char const* input_datatype_str = ""; - if (input_datatype_desc != NULL) { - // Ensure it is a string (or convert it
to one if possible) - if (!PyUnicode_Check(input_datatype_desc)) { - PyErr_SetString(PyExc_TypeError, "Third argument must be a string describing the value type"); - goto cleanup; - } - // Convert Python string to C string - input_datatype_str = PyUnicode_AsUTF8(input_datatype_desc); - if (!input_datatype_str) { - PyErr_SetString(PyExc_ValueError, "Could not convert value type description to string"); - goto cleanup; - } - input_datatype = python_string_to_datatype(input_datatype_str); - if (input_datatype == simsimd_datatype_unknown_k) { - PyErr_Format(PyExc_ValueError, "Unsupported datatype '%s'", input_datatype_str); - goto cleanup; - } - } - + // Look up the metric and the capability simsimd_metric_curved_punned_t metric = NULL; simsimd_capability_t capability = simsimd_cap_serial_k; - simsimd_find_metric_punned(metric_kind, input_datatype, static_capabilities, simsimd_cap_any_k, - (simsimd_metric_punned_t*)&metric, &capability); + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, + (simsimd_metric_punned_t *)&metric, &capability); if (!metric) { - PyErr_Format(PyExc_LookupError, - "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s' and '%s'/'%s'), " - "tensor ('%s'/'%s'), and `dtype` override ('%s'/'%s')", - metric_kind, // - buffer_a.format, datatype_to_python_string(parsed_a.datatype), // - buffer_b.format, datatype_to_python_string(parsed_b.datatype), // - buffer_c.format, datatype_to_python_string(parsed_c.datatype), // - input_datatype_str, datatype_to_python_string(input_datatype)); + PyErr_Format( // + PyExc_LookupError, + "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s' and '%s'/'%s'), " + "tensor ('%s'/'%s'), and `dtype` override ('%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + b_buffer.format ? b_buffer.format : "nil", datatype_to_python_string(b_parsed.datatype), // + c_buffer.format ? c_buffer.format : "nil", datatype_to_python_string(c_parsed.datatype), // + dtype_str ? 
dtype_str : "nil", datatype_to_python_string(dtype)); goto cleanup; } simsimd_distance_t distance; - metric(parsed_a.start, parsed_b.start, parsed_c.start, parsed_a.dimensions, &distance); - output = PyFloat_FromDouble(distance); + metric(a_parsed.start, b_parsed.start, c_parsed.start, a_parsed.dimensions, &distance); + return_obj = PyFloat_FromDouble(distance); cleanup: - PyBuffer_Release(&buffer_a); - PyBuffer_Release(&buffer_b); - PyBuffer_Release(&buffer_c); - return output; + PyBuffer_Release(&a_buffer); + PyBuffer_Release(&b_buffer); + PyBuffer_Release(&c_buffer); + return return_obj; } -static PyObject* implement_sparse_metric(simsimd_metric_kind_t metric_kind, PyObject* const* args, Py_ssize_t nargs) { +static PyObject *implement_sparse_metric( // + simsimd_metric_kind_t metric_kind, // + PyObject *const *args, Py_ssize_t nargs) { if (nargs != 2) { PyErr_SetString(PyExc_TypeError, "Function expects only 2 arguments"); return NULL; } - PyObject* output = NULL; - PyObject* input_tensor_a = args[0]; - PyObject* input_tensor_b = args[1]; + PyObject *return_obj = NULL; + PyObject *a_obj = args[0]; + PyObject *b_obj = args[1]; - Py_buffer buffer_a, buffer_b; - TensorArgument parsed_a, parsed_b; - if (parse_tensor(input_tensor_a, &buffer_a, &parsed_a) != 0 || - parse_tensor(input_tensor_b, &buffer_b, &parsed_b) != 0) { - return NULL; // Error already set by parse_tensor - } + Py_buffer a_buffer, b_buffer; + TensorArgument a_parsed, b_parsed; + if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed)) return NULL; // Check dimensions - if (parsed_a.rank != 1 || parsed_b.rank != 1) { + if (a_parsed.rank != 1 || b_parsed.rank != 1) { PyErr_SetString(PyExc_ValueError, "First and second argument must be vectors"); goto cleanup; } // Check data types - if (parsed_a.datatype != parsed_b.datatype && parsed_a.datatype != simsimd_datatype_unknown_k && - parsed_b.datatype != simsimd_datatype_unknown_k) { + if (a_parsed.datatype != b_parsed.datatype && a_parsed.datatype != simsimd_datatype_unknown_k && + b_parsed.datatype != simsimd_datatype_unknown_k) { PyErr_SetString(PyExc_TypeError, "Input tensors must have matching datatypes, check with `X.__array_interface__`"); goto cleanup; } - simsimd_datatype_t input_datatype = parsed_a.datatype; + simsimd_datatype_t dtype = a_parsed.datatype; simsimd_metric_sparse_punned_t metric = NULL; simsimd_capability_t capability = simsimd_cap_serial_k; - simsimd_find_metric_punned(metric_kind, input_datatype, static_capabilities, simsimd_cap_any_k, - (simsimd_metric_punned_t*)&metric, &capability); + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, + (simsimd_metric_punned_t *)&metric, &capability); if (!metric) { - PyErr_Format(PyExc_LookupError, "Unsupported metric '%c' and datatype combination ('%s'/'%s' and '%s'/'%s')", - metric_kind, // - buffer_a.format, datatype_to_python_string(parsed_a.datatype), // - buffer_b.format, datatype_to_python_string(parsed_b.datatype)); + PyErr_Format( // + PyExc_LookupError, "Unsupported metric '%c' and datatype combination ('%s'/'%s' and '%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + b_buffer.format ? 
b_buffer.format : "nil", datatype_to_python_string(b_parsed.datatype));
        goto cleanup;
    }

    simsimd_distance_t distance;
-    metric(parsed_a.start, parsed_b.start, parsed_a.dimensions, parsed_b.dimensions, &distance);
-    output = PyFloat_FromDouble(distance);
+    metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, b_parsed.dimensions, &distance);
+    return_obj = PyFloat_FromDouble(distance);

cleanup:
-    PyBuffer_Release(&buffer_a);
-    PyBuffer_Release(&buffer_b);
-    return output;
+    PyBuffer_Release(&a_buffer);
+    PyBuffer_Release(&b_buffer);
+    return return_obj;
}

-static PyObject* impl_cdist( //
-    PyObject* input_tensor_a, PyObject* input_tensor_b, //
-    simsimd_metric_kind_t metric_kind, size_t threads, simsimd_datatype_t input_datatype,
-    simsimd_datatype_t return_datatype) {
+static PyObject *implement_cdist( //
+    PyObject *a_obj, PyObject *b_obj, PyObject *out_obj, //
+    simsimd_metric_kind_t metric_kind, size_t threads, //
+    simsimd_datatype_t dtype, simsimd_datatype_t out_dtype) {

-    PyObject* output = NULL;
-    Py_buffer buffer_a, buffer_b;
-    TensorArgument parsed_a, parsed_b;
-    if (parse_tensor(input_tensor_a, &buffer_a, &parsed_a) != 0 ||
-        parse_tensor(input_tensor_b, &buffer_b, &parsed_b) != 0) {
-        return NULL; // Error already set by parse_tensor
-    }
+    PyObject *return_obj = NULL;
+
+    Py_buffer a_buffer, b_buffer, out_buffer;
+    TensorArgument a_parsed, b_parsed, out_parsed;
+    memset(&a_buffer, 0, sizeof(Py_buffer));
+    memset(&b_buffer, 0, sizeof(Py_buffer));
+    memset(&out_buffer, 0, sizeof(Py_buffer));
+
+    // Error will be set by `parse_tensor` if the input is invalid
+    if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed)) return NULL;
+    if (out_obj && !parse_tensor(out_obj, &out_buffer, &out_parsed)) return NULL;

    // Check dimensions
-    if (parsed_a.dimensions != parsed_b.dimensions) {
-        PyErr_Format(PyExc_ValueError, "Vector dimensions don't match (%d != %d)", parsed_a.dimensions,
-                     parsed_b.dimensions);
+    if (a_parsed.dimensions != b_parsed.dimensions) {
+        PyErr_Format(PyExc_ValueError, "Vector dimensions don't match (%zd != %zd)", a_parsed.dimensions,
+                     b_parsed.dimensions);
        goto cleanup;
    }
-    if (parsed_a.count == 0 || parsed_b.count == 0) {
+    if (a_parsed.count == 0 || b_parsed.count == 0) {
        PyErr_SetString(PyExc_ValueError, "Collections can't be empty");
        goto cleanup;
    }
+    if (out_obj &&
+        (out_parsed.rank != 2 || out_buffer.shape[0] != a_parsed.count || out_buffer.shape[1] != b_parsed.count)) {
+        PyErr_Format(PyExc_ValueError, "Output tensor must have shape (%zd, %zd)", a_parsed.count, b_parsed.count);
+        goto cleanup;
+    }

    // Check data types
-    if (parsed_a.datatype != parsed_b.datatype && parsed_a.datatype != simsimd_datatype_unknown_k &&
-        parsed_b.datatype != simsimd_datatype_unknown_k) {
+    if (a_parsed.datatype != b_parsed.datatype || //
+        a_parsed.datatype == simsimd_datatype_unknown_k || b_parsed.datatype == simsimd_datatype_unknown_k) {
        PyErr_SetString(PyExc_TypeError,
                        "Input tensors must have matching datatypes, check with `X.__array_interface__`");
        goto cleanup;
    }
-    if (input_datatype == simsimd_datatype_unknown_k)
-        input_datatype = parsed_a.datatype;
-
-    simsimd_metric_punned_t metric = NULL;
-    simsimd_capability_t capability = simsimd_cap_serial_k;
-    simsimd_find_metric_punned(metric_kind, input_datatype, static_capabilities, simsimd_cap_any_k, &metric,
-                               &capability);
-    if (!metric) {
-        PyErr_Format(PyExc_LookupError, "Unsupported metric '%c' and datatype combination ('%s'/'%s' and '%s'/'%s')",
-                     metric_kind, //
-                     buffer_a.format, 
datatype_to_python_string(parsed_a.datatype), // - buffer_b.format, datatype_to_python_string(parsed_b.datatype)); - goto cleanup; + if (dtype == simsimd_datatype_unknown_k) dtype = a_parsed.datatype; + + // Inference order for the output type: + // 1. `out_dtype` named argument, if defined + // 2. `out.dtype` attribute, if `out` is passed + // 3. double precision float (or its complex variant) + if (out_dtype == simsimd_datatype_unknown_k) { + if (out_obj) { out_dtype = out_parsed.datatype; } + else { out_dtype = is_complex(dtype) ? simsimd_datatype_f64c_k : simsimd_datatype_f64_k; } } - // Make sure the return datatype is complex if the input datatype is complex, - // and the same for real numbers - if (return_datatype != simsimd_datatype_unknown_k) { - if (is_complex(input_datatype) != is_complex(return_datatype)) { + // Make sure the return datatype is complex if the input datatype is complex, and the same for real numbers + if (out_dtype != simsimd_datatype_unknown_k) { + if (is_complex(dtype) != is_complex(out_dtype)) { PyErr_SetString( PyExc_ValueError, "If the input datatype is complex, the return datatype must be complex, and same for real."); goto cleanup; } - } else { - return_datatype = is_complex(input_datatype) ? simsimd_datatype_f64c_k : simsimd_datatype_f64_k; + } + + // Check if the downcasting to provided datatype is supported + { + char returned_buffer_example[8]; + if (!cast_distance(0, out_dtype, &returned_buffer_example, 0)) { + PyErr_SetString(PyExc_ValueError, "Exporting to the provided datatype is not supported"); + goto cleanup; + } + } + + // Look up the metric and the capability + simsimd_metric_punned_t metric = NULL; + simsimd_capability_t capability = simsimd_cap_serial_k; + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, &metric, &capability); + if (!metric) { + PyErr_Format( // + PyExc_LookupError, "Unsupported metric '%c' and datatype combination ('%s'/'%s' and '%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + b_buffer.format ? b_buffer.format : "nil", datatype_to_python_string(b_parsed.datatype)); + goto cleanup; } // If the distance is computed between two vectors, rather than matrices, return a scalar - int datatype_is_complex = is_complex(input_datatype); - if (parsed_a.rank == 1 && parsed_b.rank == 1) { + int const dtype_is_complex = is_complex(dtype); + if (a_parsed.rank == 1 && b_parsed.rank == 1) { // For complex numbers we are going to use `PyComplex_FromDoubles`. 
- if (datatype_is_complex) { + if (dtype_is_complex) { simsimd_distance_t distances[2]; - metric(parsed_a.start, parsed_b.start, parsed_a.dimensions, distances); - output = PyComplex_FromDoubles(distances[0], distances[1]); - } else { + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, distances); + return_obj = PyComplex_FromDoubles(distances[0], distances[1]); + } + else { simsimd_distance_t distance; - metric(parsed_a.start, parsed_b.start, parsed_a.dimensions, &distance); - output = PyFloat_FromDouble(distance); + metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, &distance); + return_obj = PyFloat_FromDouble(distance); } - } else { + goto cleanup; + } #ifdef __linux__ #ifdef _OPENMP - if (threads == 0) - threads = omp_get_num_procs(); - omp_set_num_threads(threads); + if (threads == 0) threads = omp_get_num_procs(); + omp_set_num_threads(threads); #endif #endif - // Check if the downcasting to provided datatype is supported - { + size_t const count_pairs = a_parsed.count * b_parsed.count; + size_t const components_per_pair = dtype_is_complex ? 2 : 1; + size_t const count_components = count_pairs * components_per_pair; + char *distances_start = NULL; + size_t distances_rows_stride_bytes = 0; + size_t distances_cols_stride_bytes = 0; - char returned_buffer_example[8]; - if (!cast_distance(0, return_datatype, &returned_buffer_example, 0)) { - PyErr_SetString(PyExc_ValueError, "Unsupported datatype"); - goto cleanup; - } - } + // Allocate the output matrix if it wasn't provided + if (!out_obj) { - size_t const count_pairs = parsed_a.count * parsed_b.count; - size_t const components_per_pair = datatype_is_complex ? 2 : 1; - size_t const count_components = count_pairs * components_per_pair; - DistancesTensor* distances_obj = PyObject_NewVar(DistancesTensor, &DistancesTensorType, - count_components * bytes_per_datatype(return_datatype)); + DistancesTensor *distances_obj = + PyObject_NewVar(DistancesTensor, &DistancesTensorType, count_components * bytes_per_datatype(out_dtype)); if (!distances_obj) { PyErr_NoMemory(); goto cleanup; } // Initialize the object - distances_obj->datatype = return_datatype; + distances_obj->datatype = out_dtype; distances_obj->dimensions = 2; - distances_obj->shape[0] = parsed_a.count; - distances_obj->shape[1] = parsed_b.count; - distances_obj->strides[0] = parsed_b.count * bytes_per_datatype(distances_obj->datatype); + distances_obj->shape[0] = a_parsed.count; + distances_obj->shape[1] = b_parsed.count; + distances_obj->strides[0] = b_parsed.count * bytes_per_datatype(distances_obj->datatype); distances_obj->strides[1] = bytes_per_datatype(distances_obj->datatype); - output = (PyObject*)distances_obj; + return_obj = (PyObject *)distances_obj; + distances_start = (char *)&distances_obj->start[0]; + distances_rows_stride_bytes = distances_obj->strides[0]; + distances_cols_stride_bytes = distances_obj->strides[1]; + } + else { + if (bytes_per_datatype(out_parsed.datatype) != bytes_per_datatype(out_dtype)) { + PyErr_Format( // + PyExc_LookupError, + "Output tensor scalar type must be compatible with the output type ('%s' and '%s'/'%s')", + datatype_to_python_string(out_dtype), out_buffer.format ? out_buffer.format : "nil", + datatype_to_python_string(out_parsed.datatype)); + goto cleanup; + } + distances_start = (char *)&out_parsed.start[0]; + distances_rows_stride_bytes = out_buffer.strides[0]; + distances_cols_stride_bytes = out_buffer.strides[1]; + //? Logic suggests to return `None` in in-place mode... + //? SciPy decided differently. 
+        return_obj = Py_None;
+        Py_INCREF(return_obj);
+    }
-
-    // Compute the distances
-    simsimd_distance_t* distances = (simsimd_distance_t*)&distances_obj->start[0];
+
+    // Assuming most of our kernels are symmetric, we only need to compute the upper triangle
+    // if we are computing all pairwise distances within the same set.
+    int const is_symmetric = kernel_is_commutative(metric_kind) && a_parsed.start == b_parsed.start &&
+                             a_parsed.stride == b_parsed.stride && a_parsed.count == b_parsed.count;

#pragma omp parallel for collapse(2)
-    for (size_t i = 0; i < parsed_a.count; ++i)
-        for (size_t j = 0; j < parsed_b.count; ++j) {
-            simsimd_distance_t result[2];
-            metric( //
-                parsed_a.start + i * parsed_a.stride, //
-                parsed_b.start + j * parsed_b.stride, //
-                parsed_a.dimensions, //
-                (simsimd_distance_t*)&result //
-            );
-            // Export out:
-            cast_distance(result[0], return_datatype, distances,
-                          i * components_per_pair * parsed_b.count + j * components_per_pair);
-            if (datatype_is_complex)
-                cast_distance(result[1], return_datatype, distances,
-                              i * components_per_pair * parsed_b.count + j * components_per_pair + 1);
-        }
-    }
+    for (size_t i = 0; i < a_parsed.count; ++i)
+        for (size_t j = 0; j < b_parsed.count; ++j) {
+            if (is_symmetric && i > j) continue;
+
+            // Export into an on-stack buffer and then copy to the output
+            simsimd_distance_t result[2];
+            metric( //
+                a_parsed.start + i * a_parsed.stride, //
+                b_parsed.start + j * b_parsed.stride, //
+                a_parsed.dimensions, //
+                (simsimd_distance_t *)&result //
+            );
+
+            // Export into the upper triangle, mirroring into the lower one for symmetric inputs
+            cast_distance(result[0], out_dtype,
+                          distances_start + i * distances_rows_stride_bytes + j * distances_cols_stride_bytes, 0);
+            if (dtype_is_complex)
+                cast_distance(result[1], out_dtype,
+                              distances_start + i * distances_rows_stride_bytes + j * distances_cols_stride_bytes, 1);
+            if (is_symmetric)
+                cast_distance(result[0], out_dtype,
+                              distances_start + j * distances_rows_stride_bytes + i * distances_cols_stride_bytes, 0);
+            if (is_symmetric && dtype_is_complex)
+                cast_distance(result[1], out_dtype,
+                              distances_start + j * distances_rows_stride_bytes + i * distances_cols_stride_bytes, 1);
+        }

cleanup:
-    PyBuffer_Release(&buffer_a);
-    PyBuffer_Release(&buffer_b);
-    return output;
+    PyBuffer_Release(&a_buffer);
+    PyBuffer_Release(&b_buffer);
+    PyBuffer_Release(&out_buffer);
+    return return_obj;
}

-static PyObject* implement_pointer_access(simsimd_metric_kind_t metric_kind, PyObject* args) {
-    char const* type_name = PyUnicode_AsUTF8(PyTuple_GetItem(args, 0));
-    if (!type_name) {
-        PyErr_SetString(PyExc_TypeError, "Invalid type name");
+static PyObject *implement_pointer_access(simsimd_metric_kind_t metric_kind, PyObject *dtype_obj) {
+    char const *dtype_name = PyUnicode_AsUTF8(dtype_obj);
+    if (!dtype_name) {
+        PyErr_SetString(PyExc_TypeError, "Data-type name must be a string");
        return NULL;
    }

-    simsimd_datatype_t datatype = python_string_to_datatype(type_name);
-    if (!datatype) { // Check the actual variable here instead of type_name
-        PyErr_SetString(PyExc_TypeError, "Unsupported type");
+    simsimd_datatype_t datatype = python_string_to_datatype(dtype_name);
+    if (!datatype) { // Check the actual variable here instead of dtype_name
+        PyErr_SetString(PyExc_ValueError, "Unsupported type");
        return NULL;
    }

@@ -946,96 +1186,83 @@ static PyObject* implement_pointer_access(simsimd_metric_kind_t metric_kind, PyO
    return PyLong_FromUnsignedLongLong((unsigned long long)metric);
}
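
The rewritten `cdist` path above adds an optional `out` buffer, an `out_dtype` override, and a symmetric fast-path. A minimal usage sketch of both calling conventions, assuming NumPy is installed; the shapes and metric name are illustrative, not taken from the diff:

```py
import numpy as np
import simsimd

a = np.random.rand(100, 256).astype(np.float32)
b = np.random.rand(200, 256).astype(np.float32)

# Allocating form: returns a 100x200 DistancesTensor
distances = simsimd.cdist(a, b, "cosine", threads=4)

# In-place form: fills `out` and returns None, as the comments above note
out = np.zeros((100, 200), dtype=np.float64)
simsimd.cdist(a, b, "cosine", out=out)
```
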

-static PyObject* api_cdist(PyObject* self, PyObject* const* args, Py_ssize_t positional_args_count, PyObject* kwnames) {
-
-    // This function accepts up to 6 arguments:
-    PyObject* input_tensor_a = NULL;   // Required object, positional-only
-    PyObject* input_tensor_b = NULL;   // Required object, positional-only
-    PyObject* metric_obj = NULL;       // Optional string, positional or keyword
-    PyObject* threads_obj = NULL;      // Optional integer, keyword-only
-    PyObject* dtype_obj = NULL;        // Optional string, keyword-only
-    PyObject* out_dtype_obj = NULL;    // Optional string, keyword-only
+static char const doc_cdist[] = //
+    "Compute pairwise distances between two sets of input matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First matrix.\n"
+    "    b (NDArray): Second matrix.\n"
+    "    metric (str, optional): Distance metric to use (e.g., 'sqeuclidean', 'cosine').\n"
+    "    out (NDArray, optional): Output matrix to store the result.\n"
+    "    dtype (Union[IntegralType, FloatType, ComplexType], optional): Override the presumed input type.\n"
+    "    out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n"
+    "    threads (int, optional): Number of threads to use (default is 1).\n\n"
+    "Returns:\n"
+    "    DistancesTensor: Pairwise distances between all inputs.\n\n"
+    "Equivalent to: `scipy.spatial.distance.cdist`.\n"
+    "Notes:\n"
+    "    * `a` and `b` are positional-only arguments.\n"
+    "    * `metric` can be positional or keyword.\n"
+    "    * `out`, `threads`, `dtype`, and `out_dtype` are keyword-only arguments.";
+
+static PyObject *api_cdist( //
+    PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count, PyObject *args_names_tuple) {
+
+    // This function accepts up to 7 arguments - more than SciPy:
+    // https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html
+    PyObject *a_obj = NULL;         // Required object, positional-only
+    PyObject *b_obj = NULL;         // Required object, positional-only
+    PyObject *metric_obj = NULL;    // Optional string, "metric" keyword or positional
+    PyObject *out_obj = NULL;       // Optional object, "out" keyword-only
+    PyObject *dtype_obj = NULL;     // Optional string, "dtype" keyword-only
+    PyObject *out_dtype_obj = NULL; // Optional string, "out_dtype" keyword-only
+    PyObject *threads_obj = NULL;   // Optional integer, "threads" keyword-only

    // Once parsed, the arguments will be stored in these variables:
-    char const* metric_str = NULL;
    unsigned long long threads = 1;
-    char const* dtype_str = NULL;
-    char const* out_dtype_str = NULL;
-
-    // The lazy implementation would be to use `PyArg_ParseTupleAndKeywords` for a `kwnames` dictionary:
-    // static char* kwlist[] = {"input_tensor_a", "input_tensor_b", "metric", "threads", "dtype", "out_dtype", NULL};
-    // if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|s$Kss", kwlist, &input_tensor_a, &input_tensor_b, &metric_str,
-    //                                  &threads, &dtype_str, &out_dtype_str))
-    //     return NULL;
-    Py_ssize_t kwnames_count = kwnames ? 
PyTuple_Size(kwnames) : 0; - Py_ssize_t args_count = positional_args_count + kwnames_count; - if (args_count < 2 || args_count > 6) { - PyErr_Format(PyExc_TypeError, "Function expects 2-6 arguments, got %d", args_count); + char const *dtype_str = NULL, *out_dtype_str = NULL; + simsimd_datatype_t dtype = simsimd_datatype_unknown_k, out_dtype = simsimd_datatype_unknown_k; + + /// Same default as in SciPy: + /// https://docs.scipy.org/doc/scipy-1.11.4/reference/generated/scipy.spatial.distance.cdist.html + simsimd_metric_kind_t metric_kind = simsimd_metric_euclidean_k; + char const *metric_str = NULL; + + // Parse the arguments + Py_ssize_t const args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0; + Py_ssize_t const args_count = positional_args_count + args_names_count; + if (args_count < 2 || args_count > 7) { + PyErr_Format(PyExc_TypeError, "Function expects 2-7 arguments, got %zd", args_count); + return NULL; + } + if (positional_args_count > 3) { + PyErr_Format(PyExc_TypeError, "Only first 3 arguments can be positional, received %zd", positional_args_count); return NULL; } // Positional-only arguments (first and second matrix) - input_tensor_a = args[0]; - input_tensor_b = args[1]; + a_obj = args[0]; + b_obj = args[1]; // Positional or keyword arguments (metric) - Py_ssize_t args_progress = 2; - Py_ssize_t remaining_positional_args_count = args_count - args_progress - kwnames_count; - if (remaining_positional_args_count == 1) { - metric_obj = args[2]; - args_progress = 3; - } else if (remaining_positional_args_count > 1) { - PyErr_Format(PyExc_TypeError, "Only first 3 arguments can be positional, received %zd", - remaining_positional_args_count); - return NULL; - } - - // The rest of the arguments must be checked in the keyword dictionary. - // There may be a case, when the call is ill-formed and more positional arguments are provided than needed. 
- // For a call like: - // - // cdist(a, b, "cos", "dos"): positional_args_count == 4, kwnames_count == 0 - // cdist(a, b, "cos", metric="dos"): positional_args_count == 3, kwnames_count == 1 - // cdist(a, b, metric="cos", metric="dos"): positional_args_count == 2, kwnames_count == 2 - // - // https://ashvardanian.com/posts/discount-on-keyword-arguments-in-python/ - for (Py_ssize_t kwnames_progress = 0; kwnames_progress < kwnames_count; ++args_progress, ++kwnames_progress) { - PyObject* key = PyTuple_GetItem(kwnames, kwnames_progress); - PyObject* value = args[args_progress]; - if (PyUnicode_CompareWithASCIIString(key, "threads") == 0) { - if (threads_obj != NULL) { - PyErr_SetString(PyExc_TypeError, "Got multiple values for argument 'threads'"); - return NULL; - } - threads_obj = value; - } else if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0) { - if (dtype_obj != NULL) { - PyErr_SetString(PyExc_TypeError, "Got multiple values for argument 'dtype'"); - return NULL; - } - dtype_obj = value; - } else if (PyUnicode_CompareWithASCIIString(key, "out_dtype") == 0) { - if (out_dtype_obj != NULL) { - PyErr_SetString(PyExc_TypeError, "Got multiple values for argument 'out_dtype'"); - return NULL; - } - out_dtype_obj = value; - } else if (PyUnicode_CompareWithASCIIString(key, "metric") == 0) { - if (metric_obj != NULL) { - PyErr_SetString(PyExc_TypeError, "Got multiple values for argument 'metric'"); - return NULL; - } - metric_obj = value; - } else { + if (positional_args_count == 3) metric_obj = args[2]; + + // The rest of the arguments must be checked in the keyword dictionary: + for (Py_ssize_t args_names_tuple_progress = 0, args_progress = positional_args_count; + args_names_tuple_progress < args_names_count; ++args_progress, ++args_names_tuple_progress) { + PyObject *const key = PyTuple_GetItem(args_names_tuple, args_names_tuple_progress); + PyObject *const value = args[args_progress]; + if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0 && !dtype_obj) { dtype_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "out") == 0 && !out_obj) { out_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "out_dtype") == 0 && !out_dtype_obj) { out_dtype_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "threads") == 0 && !threads_obj) { threads_obj = value; } + else if (PyUnicode_CompareWithASCIIString(key, "metric") == 0 && !metric_obj) { metric_obj = value; } + else { PyErr_Format(PyExc_TypeError, "Got unexpected keyword argument: %S", key); return NULL; } } - // Process the PyObject values - /// Same default as in SciPy: - /// https://docs.scipy.org/doc/scipy-1.11.4/reference/generated/scipy.spatial.distance.cdist.html - simsimd_metric_kind_t metric_kind = simsimd_metric_euclidean_k; + // Convert `metric_obj` to `metric_str` and to `metric_kind` if (metric_obj) { metric_str = PyUnicode_AsUTF8(metric_obj); if (!metric_str && PyErr_Occurred()) { @@ -1049,15 +1276,14 @@ static PyObject* api_cdist(PyObject* self, PyObject* const* args, Py_ssize_t pos } } - threads = 1; - if (threads_obj) - threads = PyLong_AsSize_t(threads_obj); + // Convert `threads_obj` to `threads` integer + if (threads_obj) threads = PyLong_AsSize_t(threads_obj); if (PyErr_Occurred()) { PyErr_SetString(PyExc_TypeError, "Expected 'threads' to be an unsigned integer"); return NULL; } - simsimd_datatype_t dtype = simsimd_datatype_unknown_k; + // Convert `dtype_obj` to `dtype_str` and to `dtype` if (dtype_obj) { dtype_str = PyUnicode_AsUTF8(dtype_obj); if (!dtype_str && PyErr_Occurred()) { 
@@ -1071,7 +1297,7 @@ static PyObject* api_cdist(PyObject* self, PyObject* const* args, Py_ssize_t pos } } - simsimd_datatype_t out_dtype = simsimd_datatype_unknown_k; + // Convert `out_dtype_obj` to `out_dtype_str` and to `out_dtype` if (out_dtype_obj) { out_dtype_str = PyUnicode_AsUTF8(out_dtype_obj); if (!out_dtype_str && PyErr_Occurred()) { @@ -1085,400 +1311,741 @@ static PyObject* api_cdist(PyObject* self, PyObject* const* args, Py_ssize_t pos } } - return impl_cdist(input_tensor_a, input_tensor_b, metric_kind, threads, dtype, out_dtype); + return implement_cdist(a_obj, b_obj, out_obj, metric_kind, threads, dtype, out_dtype); } -static PyObject* api_l2_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_l2_k, args); +static char const doc_l2_pointer[] = "Get (int) pointer to the `simsimd.l2` kernel."; +static PyObject *api_l2_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_l2_k, dtype_obj); } -static PyObject* api_l2sq_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_l2sq_k, args); +static char const doc_l2sq_pointer[] = "Get (int) pointer to the `simsimd.l2sq` kernel."; +static PyObject *api_l2sq_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_l2sq_k, dtype_obj); } -static PyObject* api_cos_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_cos_k, args); +static char const doc_cos_pointer[] = "Get (int) pointer to the `simsimd.cos` kernel."; +static PyObject *api_cos_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_cos_k, dtype_obj); } -static PyObject* api_dot_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_dot_k, args); +static char const doc_dot_pointer[] = "Get (int) pointer to the `simsimd.dot` kernel."; +static PyObject *api_dot_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_dot_k, dtype_obj); } -static PyObject* api_vdot_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_vdot_k, args); +static char const doc_vdot_pointer[] = "Get (int) pointer to the `simsimd.vdot` kernel."; +static PyObject *api_vdot_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_vdot_k, dtype_obj); } -static PyObject* api_kl_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_kl_k, args); +static char const doc_kl_pointer[] = "Get (int) pointer to the `simsimd.kl` kernel."; +static PyObject *api_kl_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_kl_k, dtype_obj); } -static PyObject* api_js_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_js_k, args); +static char const doc_js_pointer[] = "Get (int) pointer to the `simsimd.js` kernel."; +static PyObject *api_js_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_js_k, dtype_obj); } -static PyObject* api_hamming_pointer(PyObject* self, PyObject* args) { - return implement_pointer_access(simsimd_metric_hamming_k, args); +static char const doc_hamming_pointer[] = "Get (int) pointer to the `simsimd.hamming` kernel."; +static PyObject *api_hamming_pointer(PyObject *self, PyObject *dtype_obj) { + return implement_pointer_access(simsimd_metric_hamming_k, 
dtype_obj);
}

-static PyObject* api_jaccard_pointer(PyObject* self, PyObject* args) {
-    return implement_pointer_access(simsimd_metric_jaccard_k, args);
+static char const doc_jaccard_pointer[] = "Get (int) pointer to the `simsimd.jaccard` kernel.";
+static PyObject *api_jaccard_pointer(PyObject *self, PyObject *dtype_obj) {
+    return implement_pointer_access(simsimd_metric_jaccard_k, dtype_obj);
}

-static PyObject* api_l2(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_l2_k, args, nargs);
+
+static char const doc_l2[] = //
+    "Compute Euclidean (L2) distances between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First matrix or vector.\n"
+    "    b (NDArray): Second matrix or vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `scipy.spatial.distance.euclidean`.\n"
+    "Signature:\n"
+    "    >>> def euclidean(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_l2(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                        PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_l2_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_l2sq(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_l2sq_k, args, nargs);
+
+static char const doc_l2sq[] = //
+    "Compute squared Euclidean (L2) distances between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First matrix or vector.\n"
+    "    b (NDArray): Second matrix or vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `scipy.spatial.distance.sqeuclidean`.\n"
+    "Signature:\n"
+    "    >>> def sqeuclidean(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_l2sq(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                          PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_l2sq_k, args, positional_args_count, args_names_tuple);
}
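
The `out` and `out_dtype` keyword arguments described in these docstrings enable allocation-free calls. A minimal sketch, assuming NumPy and illustrative shapes:

```py
import numpy as np
import simsimd

a = np.random.rand(8, 1536).astype(np.float32)
b = np.random.rand(8, 1536).astype(np.float32)

# Allocating form: one distance per row pair
distances = simsimd.sqeuclidean(a, b)

# In-place form: fills `out` and returns None
out = np.zeros(8, dtype=np.float64)
simsimd.sqeuclidean(a, b, out=out)
```
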

-static PyObject* api_cos(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_cos_k, args, nargs);
+
+static char const doc_cos[] = //
+    "Compute cosine (angular) distances between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First matrix or vector.\n"
+    "    b (NDArray): Second matrix or vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `scipy.spatial.distance.cosine`.\n"
+    "Signature:\n"
+    "    >>> def cosine(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_cos(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                         PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_cos_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_dot(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_dot_k, args, nargs);
+
+static char const doc_dot[] = //
+    "Compute the inner (dot) product between two matrices (real or complex).\n\n"
+    "Args:\n"
+    "    a (NDArray): First matrix or vector.\n"
+    "    b (NDArray): Second matrix or vector.\n"
+    "    dtype (Union[IntegralType, FloatType, ComplexType], optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `numpy.inner`.\n"
+    "Signature:\n"
+    "    >>> def dot(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_dot(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                         PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_dot_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_vdot(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_vdot_k, args, nargs);
+
+static char const doc_vdot[] = //
+    "Compute the conjugate dot product between two complex matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First complex matrix or vector.\n"
+    "    b (NDArray): Second complex matrix or vector.\n"
+    "    dtype (ComplexType, optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (ComplexType, optional): Result type, default is 'complex128'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `numpy.vdot`.\n"
+    "Signature:\n"
+    "    >>> def vdot(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_vdot(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                          PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_vdot_k, args, positional_args_count, args_names_tuple);
}
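
`dot` and `vdot` are the only kernels here that accept complex inputs; `vdot` conjugates its first argument. A minimal sketch, assuming NumPy:

```py
import numpy as np
import simsimd

a = (np.random.rand(1536) + 1j * np.random.rand(1536)).astype(np.complex64)
b = (np.random.rand(1536) + 1j * np.random.rand(1536)).astype(np.complex64)

simsimd.dot(a, b)   # sum of a[i] * b[i], a complex scalar
simsimd.vdot(a, b)  # sum of conj(a[i]) * b[i], like numpy.vdot
```
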

-static PyObject* api_kl(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_kl_k, args, nargs);
+
+static char const doc_kl[] = //
+    "Compute Kullback-Leibler divergences between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First floating-point matrix or vector.\n"
+    "    b (NDArray): Second floating-point matrix or vector.\n"
+    "    dtype (FloatType, optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `scipy.special.kl_div`.\n"
+    "Signature:\n"
+    "    >>> def kl(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_kl(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                        PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_kl_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_js(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_js_k, args, nargs);
+
+static char const doc_js[] = //
+    "Compute Jensen-Shannon divergences between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First floating-point matrix or vector.\n"
+    "    b (NDArray): Second floating-point matrix or vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `scipy.spatial.distance.jensenshannon`.\n"
+    "Signature:\n"
+    "    >>> def js(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_js(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                        PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_js_k, args, positional_args_count, args_names_tuple);
}
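
Both divergences expect the inputs to behave like probability distributions. A minimal sketch, assuming NumPy and normalized inputs:

```py
import numpy as np
import simsimd

p = np.random.rand(128).astype(np.float32)
q = np.random.rand(128).astype(np.float32)
p /= p.sum()  # KL and JS are defined over probability distributions
q /= q.sum()

simsimd.kl(p, q)  # Kullback-Leibler divergence
simsimd.js(p, q)  # Jensen-Shannon divergence
```
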

-static PyObject* api_hamming(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_hamming_k, args, nargs);
+
+static char const doc_hamming[] = //
+    "Compute Hamming distances between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First binary matrix or vector.\n"
+    "    b (NDArray): Second binary matrix or vector.\n"
+    "    dtype (IntegralType, optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Similar to: `scipy.spatial.distance.hamming`.\n"
+    "Signature:\n"
+    "    >>> def hamming(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_hamming(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                             PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_hamming_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_jaccard(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_dense_metric(simsimd_metric_jaccard_k, args, nargs);
+
+static char const doc_jaccard[] = //
+    "Compute Jaccard distances (bitwise Tanimoto) between two matrices.\n\n"
+    "Args:\n"
+    "    a (NDArray): First binary matrix or vector.\n"
+    "    b (NDArray): Second binary matrix or vector.\n"
+    "    dtype (IntegralType, optional): Override the presumed input type.\n"
+    "    out (NDArray, optional): Vector for resulting distances. Allocates a new tensor by default.\n"
+    "    out_dtype (FloatType, optional): Result type, default is 'float64'.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Similar to: `scipy.spatial.distance.jaccard`.\n"
+    "Signature:\n"
+    "    >>> def jaccard(a, b, /, dtype, *, out, out_dtype) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_jaccard(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                             PyObject *args_names_tuple) {
+    return implement_dense_metric(simsimd_metric_jaccard_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_bilinear(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_curved_metric(simsimd_metric_bilinear_k, args, nargs);
+
+static char const doc_bilinear[] = //
+    "Compute the bilinear form between two vectors given a metric tensor.\n\n"
+    "Args:\n"
+    "    a (NDArray): First vector.\n"
+    "    b (NDArray): Second vector.\n"
+    "    metric_tensor (NDArray): The metric tensor defining the bilinear form.\n"
+    "    dtype (FloatType, optional): Override the presumed input type.\n\n"
+    "Returns:\n"
+    "    float: The bilinear form.\n\n"
+    "Equivalent to: `numpy.dot` with a metric tensor.\n"
+    "Signature:\n"
+    "    >>> def bilinear(a, b, metric_tensor, /, dtype) -> float: ...";
+
+static PyObject *api_bilinear(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                              PyObject *args_names_tuple) {
+    return implement_curved_metric(simsimd_metric_bilinear_k, args, positional_args_count, args_names_tuple);
}

-static PyObject* api_mahalanobis(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
-    return implement_curved_metric(simsimd_metric_mahalanobis_k, args, nargs);
+
+static char const doc_mahalanobis[] = //
+    "Compute the Mahalanobis distance between two vectors given an inverse covariance matrix.\n\n"
+    "Args:\n"
+    "    a (NDArray): First vector.\n"
+    "    b (NDArray): Second vector.\n"
+    "    inverse_covariance (NDArray): The inverse of the covariance matrix.\n"
+    "    dtype (FloatType, optional): Override the presumed input type.\n\n"
+    "Returns:\n"
+    "    float: The Mahalanobis distance.\n\n"
+    "Equivalent to: `scipy.spatial.distance.mahalanobis`.\n"
+    "Signature:\n"
+    "    >>> def mahalanobis(a, b, inverse_covariance, /, dtype) -> float: ...";
+
+static PyObject *api_mahalanobis(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                                 PyObject *args_names_tuple) {
+    return implement_curved_metric(simsimd_metric_mahalanobis_k, args, positional_args_count, args_names_tuple);
}
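
Both curved kernels take a rank-2 tensor as the third positional argument, as enforced by `implement_curved_metric` above. A minimal sketch, assuming NumPy, with an identity matrix standing in for the metric tensor / inverse covariance:

```py
import numpy as np
import simsimd

a = np.random.rand(4).astype(np.float64)
b = np.random.rand(4).astype(np.float64)
m = np.eye(4, dtype=np.float64)

simsimd.bilinear(a, b, m)     # scalar: a @ m @ b
simsimd.mahalanobis(a, b, m)  # scalar: sqrt((a - b) @ m @ (a - b))
```
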

-static PyObject* api_intersect(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
+
+static char const doc_intersect[] = //
+    "Compute the intersection of two sorted integer arrays.\n\n"
+    "Args:\n"
+    "    a (NDArray): First sorted integer array.\n"
+    "    b (NDArray): Second sorted integer array.\n\n"
+    "Returns:\n"
+    "    float: The number of intersecting elements.\n\n"
+    "Similar to: `numpy.intersect1d`.\n"
+    "Signature:\n"
+    "    >>> def intersect(a, b, /) -> float: ...";
+
+static PyObject *api_intersect(PyObject *self, PyObject *const *args, Py_ssize_t nargs) {
    return implement_sparse_metric(simsimd_metric_intersect_k, args, nargs);
}
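
A minimal sketch of the set intersection kernel, assuming NumPy and that the `uint32` element type is among the supported integer types:

```py
import numpy as np
import simsimd

a = np.array([1, 3, 5, 7], dtype=np.uint32)
b = np.array([3, 4, 5, 6], dtype=np.uint32)

simsimd.intersect(a, b)  # 2.0 - the arrays share {3, 5}
```
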

+static char const doc_fma[] = //
+    "Fused-Multiply-Add between 3 input vectors.\n\n"
+    "Args:\n"
+    "    a (NDArray): First vector.\n"
+    "    b (NDArray): Second vector.\n"
+    "    c (NDArray): Third vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed numeric type.\n"
+    "    alpha (float, optional): First scale, 1.0 by default.\n"
+    "    beta (float, optional): Second scale, 1.0 by default.\n"
+    "    out (NDArray, optional): Vector for resulting distances.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `alpha * a * b + beta * c`.\n"
+    "Signature:\n"
+    "    >>> def fma(a, b, c, /, dtype, *, alpha, beta, out) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_fma(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                         PyObject *args_names_tuple) {
+
+    PyObject *return_obj = NULL;
+
+    // This function accepts up to 7 arguments:
+    PyObject *a_obj = NULL;     // Required object, positional-only
+    PyObject *b_obj = NULL;     // Required object, positional-only
+    PyObject *c_obj = NULL;     // Required object, positional-only
+    PyObject *dtype_obj = NULL; // Optional object, "dtype" keyword or positional
+    PyObject *out_obj = NULL;   // Optional object, "out" keyword-only
+    PyObject *alpha_obj = NULL; // Optional object, "alpha" keyword-only
+    PyObject *beta_obj = NULL;  // Optional object, "beta" keyword-only
+
+    // Once parsed, the arguments will be stored in these variables:
+    char const *dtype_str = NULL;
+    simsimd_datatype_t dtype = simsimd_datatype_unknown_k;
+    simsimd_distance_t alpha = 1, beta = 1;
+
+    Py_buffer a_buffer, b_buffer, c_buffer, out_buffer;
+    TensorArgument a_parsed, b_parsed, c_parsed, out_parsed;
+    memset(&a_buffer, 0, sizeof(Py_buffer));
+    memset(&b_buffer, 0, sizeof(Py_buffer));
+    memset(&c_buffer, 0, sizeof(Py_buffer));
+    memset(&out_buffer, 0, sizeof(Py_buffer));
+
+    Py_ssize_t const args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0;
+    Py_ssize_t const args_count = positional_args_count + args_names_count;
+    if (args_count < 3 || args_count > 7) {
+        PyErr_Format(PyExc_TypeError, "Function expects 3-7 arguments, got %zd", args_count);
+        return NULL;
+    }
+    if (positional_args_count > 4) {
+        PyErr_Format(PyExc_TypeError, "Only first 4 arguments can be positional, received %zd", positional_args_count);
+        return NULL;
+    }
+
+    // Positional-only arguments (first three vectors)
+    a_obj = args[0];
+    b_obj = args[1];
+    c_obj = args[2];
+
+    // Positional or keyword arguments (dtype)
+    if (positional_args_count == 4) dtype_obj = args[3];
+
+    // The rest of the arguments must be checked in the keyword dictionary:
+    for (Py_ssize_t args_names_tuple_progress = 0, args_progress = positional_args_count;
+         args_names_tuple_progress < args_names_count; ++args_progress, ++args_names_tuple_progress) {
+        PyObject *const key = PyTuple_GetItem(args_names_tuple, args_names_tuple_progress);
+        PyObject *const value = args[args_progress];
+        if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0 && !dtype_obj) { dtype_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "out") == 0 && !out_obj) { out_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "alpha") == 0 && !alpha_obj) { alpha_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "beta") == 0 && !beta_obj) { beta_obj = value; }
+        else {
+            PyErr_Format(PyExc_TypeError, "Got unexpected keyword argument: %S", key);
+            return NULL;
+        }
+    }
+
+    // Convert `dtype_obj` to `dtype_str` and to `dtype`
+    if (dtype_obj) {
+        dtype_str = PyUnicode_AsUTF8(dtype_obj);
+        if (!dtype_str && PyErr_Occurred()) {
+            PyErr_SetString(PyExc_TypeError, "Expected 'dtype' to be a string");
+            return NULL;
+        }
+        dtype = python_string_to_datatype(dtype_str);
+        if (dtype == simsimd_datatype_unknown_k) {
+            PyErr_SetString(PyExc_ValueError, "Unsupported 'dtype'");
+            return NULL;
+        }
+    }
+
+    // Convert `alpha_obj` to `alpha` and `beta_obj` to `beta`
+    if (alpha_obj) alpha = PyFloat_AsDouble(alpha_obj);
+    if (beta_obj) beta = PyFloat_AsDouble(beta_obj);
+    if (PyErr_Occurred()) {
+        PyErr_SetString(PyExc_TypeError, "Expected 'alpha' and 'beta' to be floats");
+        return NULL;
+    }
+
+    // Convert `a_obj` to `a_buffer` and to `a_parsed`. Same for `b_obj`, `c_obj`, and `out_obj`.
+ if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed) || + !parse_tensor(c_obj, &c_buffer, &c_parsed)) + return NULL; + if (out_obj && !parse_tensor(out_obj, &out_buffer, &out_parsed)) return NULL; + + // Check dimensions + if (a_parsed.rank != 1 || b_parsed.rank != 1 || c_parsed.rank != 1 || (out_obj && out_parsed.rank != 1)) { + PyErr_SetString(PyExc_ValueError, "All tensors must be vectors"); + goto cleanup; + } + if (a_parsed.dimensions != b_parsed.dimensions || a_parsed.dimensions != c_parsed.dimensions || + (out_obj && a_parsed.dimensions != out_parsed.dimensions)) { + PyErr_SetString(PyExc_ValueError, "Vector dimensions don't match"); + goto cleanup; + } + + // Check data types + if (a_parsed.datatype != b_parsed.datatype || a_parsed.datatype == simsimd_datatype_unknown_k || + b_parsed.datatype == simsimd_datatype_unknown_k || c_parsed.datatype == simsimd_datatype_unknown_k || + (out_obj && out_parsed.datatype == simsimd_datatype_unknown_k)) { + PyErr_SetString(PyExc_TypeError, + "Input tensors must have matching datatypes, check with `X.__array_interface__`"); + goto cleanup; + } + if (dtype == simsimd_datatype_unknown_k) dtype = a_parsed.datatype; + + // Look up the metric and the capability + simsimd_kernel_fma_punned_t metric = NULL; + simsimd_capability_t capability = simsimd_cap_serial_k; + simsimd_metric_kind_t const metric_kind = simsimd_metric_fma_k; + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, + (simsimd_metric_punned_t *)&metric, &capability); + if (!metric) { + PyErr_Format( // + PyExc_LookupError, + "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s') and " + "`dtype` override ('%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + dtype_str ? dtype_str : "nil", datatype_to_python_string(dtype)); + goto cleanup; + } + + char *distances_start = NULL; + size_t distances_stride_bytes = 0; + + // Allocate the output matrix if it wasn't provided + if (!out_obj) { + DistancesTensor *distances_obj = + PyObject_NewVar(DistancesTensor, &DistancesTensorType, a_parsed.dimensions * bytes_per_datatype(dtype)); + if (!distances_obj) { + PyErr_NoMemory(); + goto cleanup; + } + + // Initialize the object + distances_obj->datatype = dtype; + distances_obj->dimensions = 1; + distances_obj->shape[0] = a_parsed.dimensions; + distances_obj->shape[1] = 1; + distances_obj->strides[0] = bytes_per_datatype(dtype); + distances_obj->strides[1] = 0; + return_obj = (PyObject *)distances_obj; + distances_start = (char *)&distances_obj->start[0]; + distances_stride_bytes = distances_obj->strides[0]; + } + else { + distances_start = (char *)&out_parsed.start[0]; + distances_stride_bytes = out_buffer.strides[0]; + //? Logic suggests to return `None` in in-place mode... + //? SciPy decided differently. 
+        return_obj = Py_None;
+        Py_INCREF(return_obj);
+    }
+
+    metric(a_parsed.start, b_parsed.start, c_parsed.start, a_parsed.dimensions, alpha, beta, distances_start);
+cleanup:
+    PyBuffer_Release(&a_buffer);
+    PyBuffer_Release(&b_buffer);
+    PyBuffer_Release(&c_buffer);
+    PyBuffer_Release(&out_buffer);
+    return return_obj;
+}
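
A minimal sketch of the `fma` entry point defined above, assuming NumPy and illustrative scales:

```py
import numpy as np
import simsimd

a, b, c = (np.random.rand(1536).astype(np.float32) for _ in range(3))

# Allocating form: 2.0 * a * b + 0.5 * c
result = simsimd.fma(a, b, c, alpha=2.0, beta=0.5)

# In-place form, matching the `out` branch above
out = np.zeros_like(a)
simsimd.fma(a, b, c, alpha=2.0, beta=0.5, out=out)
```
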
+
+static char const doc_wsum[] = //
+    "Weighted Sum of 2 input vectors.\n\n"
+    "Args:\n"
+    "    a (NDArray): First vector.\n"
+    "    b (NDArray): Second vector.\n"
+    "    dtype (Union[IntegralType, FloatType], optional): Override the presumed numeric type.\n"
+    "    alpha (float, optional): First scale, 1.0 by default.\n"
+    "    beta (float, optional): Second scale, 1.0 by default.\n"
+    "    out (NDArray, optional): Vector for resulting distances.\n\n"
+    "Returns:\n"
+    "    DistancesTensor: The distances if `out` is not provided.\n"
+    "    None: If `out` is provided. Operation will be performed in-place.\n\n"
+    "Equivalent to: `alpha * a + beta * b`.\n"
+    "Signature:\n"
+    "    >>> def wsum(a, b, /, dtype, *, alpha, beta, out) -> Optional[DistancesTensor]: ...";
+
+static PyObject *api_wsum(PyObject *self, PyObject *const *args, Py_ssize_t const positional_args_count,
+                          PyObject *args_names_tuple) {
+
+    PyObject *return_obj = NULL;
+
+    // This function accepts up to 6 arguments:
+    PyObject *a_obj = NULL;     // Required object, positional-only
+    PyObject *b_obj = NULL;     // Required object, positional-only
+    PyObject *dtype_obj = NULL; // Optional object, "dtype" keyword or positional
+    PyObject *out_obj = NULL;   // Optional object, "out" keyword-only
+    PyObject *alpha_obj = NULL; // Optional object, "alpha" keyword-only
+    PyObject *beta_obj = NULL;  // Optional object, "beta" keyword-only
+
+    // Once parsed, the arguments will be stored in these variables:
+    char const *dtype_str = NULL;
+    simsimd_datatype_t dtype = simsimd_datatype_unknown_k;
+    simsimd_distance_t alpha = 1, beta = 1;
+
+    Py_buffer a_buffer, b_buffer, out_buffer;
+    TensorArgument a_parsed, b_parsed, out_parsed;
+    memset(&a_buffer, 0, sizeof(Py_buffer));
+    memset(&b_buffer, 0, sizeof(Py_buffer));
+    memset(&out_buffer, 0, sizeof(Py_buffer));
+
+    Py_ssize_t const args_names_count = args_names_tuple ? PyTuple_Size(args_names_tuple) : 0;
+    Py_ssize_t const args_count = positional_args_count + args_names_count;
+    if (args_count < 2 || args_count > 6) {
+        PyErr_Format(PyExc_TypeError, "Function expects 2-6 arguments, got %zd", args_count);
+        return NULL;
+    }
+    if (positional_args_count > 3) {
+        PyErr_Format(PyExc_TypeError, "Only first 3 arguments can be positional, received %zd", positional_args_count);
+        return NULL;
+    }
+
+    // Positional-only arguments (first and second vectors)
+    a_obj = args[0];
+    b_obj = args[1];
+
+    // Positional or keyword arguments (dtype)
+    if (positional_args_count == 3) dtype_obj = args[2];
+
+    // The rest of the arguments must be checked in the keyword dictionary:
+    for (Py_ssize_t args_names_tuple_progress = 0, args_progress = positional_args_count;
+         args_names_tuple_progress < args_names_count; ++args_progress, ++args_names_tuple_progress) {
+        PyObject *const key = PyTuple_GetItem(args_names_tuple, args_names_tuple_progress);
+        PyObject *const value = args[args_progress];
+        if (PyUnicode_CompareWithASCIIString(key, "dtype") == 0 && !dtype_obj) { dtype_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "out") == 0 && !out_obj) { out_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "alpha") == 0 && !alpha_obj) { alpha_obj = value; }
+        else if (PyUnicode_CompareWithASCIIString(key, "beta") == 0 && !beta_obj) { beta_obj = value; }
+        else {
+            PyErr_Format(PyExc_TypeError, "Got unexpected keyword argument: %S", key);
+            return NULL;
+        }
+    }
+
+    // Convert `dtype_obj` to `dtype_str` and to `dtype`
+    if (dtype_obj) {
+        dtype_str = PyUnicode_AsUTF8(dtype_obj);
+        if (!dtype_str && PyErr_Occurred()) {
+            PyErr_SetString(PyExc_TypeError, "Expected 'dtype' to be a string");
+            return NULL;
+        }
+        dtype = python_string_to_datatype(dtype_str);
+        if (dtype == simsimd_datatype_unknown_k) {
+            PyErr_SetString(PyExc_ValueError, "Unsupported 'dtype'");
+            return NULL;
+        }
+    }
+
+    // Convert `alpha_obj` to `alpha` and `beta_obj` to `beta`
+    if (alpha_obj) alpha = PyFloat_AsDouble(alpha_obj);
+    if (beta_obj) beta = PyFloat_AsDouble(beta_obj);
+    if (PyErr_Occurred()) {
+        PyErr_SetString(PyExc_TypeError, "Expected 'alpha' and 'beta' to be floats");
+        return NULL;
+    }
+
+    // Convert `a_obj` to `a_buffer` and to `a_parsed`. Same for `b_obj` and `out_obj`.
+ if (!parse_tensor(a_obj, &a_buffer, &a_parsed) || !parse_tensor(b_obj, &b_buffer, &b_parsed)) return NULL; + if (out_obj && !parse_tensor(out_obj, &out_buffer, &out_parsed)) return NULL; + + // Check dimensions + if (a_parsed.rank != 1 || b_parsed.rank != 1 || (out_obj && out_parsed.rank != 1)) { + PyErr_SetString(PyExc_ValueError, "All tensors must be vectors"); + goto cleanup; + } + if (a_parsed.dimensions != b_parsed.dimensions || (out_obj && a_parsed.dimensions != out_parsed.dimensions)) { + PyErr_SetString(PyExc_ValueError, "Vector dimensions don't match"); + goto cleanup; + } + + // Check data types + if (a_parsed.datatype != b_parsed.datatype || a_parsed.datatype == simsimd_datatype_unknown_k || + b_parsed.datatype == simsimd_datatype_unknown_k || + (out_obj && out_parsed.datatype == simsimd_datatype_unknown_k)) { + PyErr_SetString(PyExc_TypeError, + "Input tensors must have matching datatypes, check with `X.__array_interface__`"); + goto cleanup; + } + if (dtype == simsimd_datatype_unknown_k) dtype = a_parsed.datatype; + + // Look up the metric and the capability + simsimd_kernel_wsum_punned_t metric = NULL; + simsimd_capability_t capability = simsimd_cap_serial_k; + simsimd_metric_kind_t const metric_kind = simsimd_metric_wsum_k; + simsimd_find_metric_punned(metric_kind, dtype, static_capabilities, simsimd_cap_any_k, + (simsimd_metric_punned_t *)&metric, &capability); + if (!metric) { + PyErr_Format( // + PyExc_LookupError, + "Unsupported metric '%c' and datatype combination across vectors ('%s'/'%s') and " + "`dtype` override ('%s'/'%s')", + metric_kind, // + a_buffer.format ? a_buffer.format : "nil", datatype_to_python_string(a_parsed.datatype), // + dtype_str ? dtype_str : "nil", datatype_to_python_string(dtype)); + goto cleanup; + } + + char *distances_start = NULL; + size_t distances_stride_bytes = 0; + + // Allocate the output matrix if it wasn't provided + if (!out_obj) { + DistancesTensor *distances_obj = + PyObject_NewVar(DistancesTensor, &DistancesTensorType, a_parsed.dimensions * bytes_per_datatype(dtype)); + if (!distances_obj) { + PyErr_NoMemory(); + goto cleanup; + } + + // Initialize the object + distances_obj->datatype = dtype; + distances_obj->dimensions = 1; + distances_obj->shape[0] = a_parsed.dimensions; + distances_obj->shape[1] = 1; + distances_obj->strides[0] = bytes_per_datatype(dtype); + distances_obj->strides[1] = 0; + return_obj = (PyObject *)distances_obj; + distances_start = (char *)&distances_obj->start[0]; + distances_stride_bytes = distances_obj->strides[0]; + } + else { + distances_start = (char *)&out_parsed.start[0]; + distances_stride_bytes = out_buffer.strides[0]; + //? Logic suggests to return `None` in in-place mode... + //? SciPy decided differently. 
+        return_obj = Py_None;
+        Py_INCREF(return_obj);
+    }
+
+    metric(a_parsed.start, b_parsed.start, a_parsed.dimensions, alpha, beta, distances_start);
+cleanup:
+    PyBuffer_Release(&a_buffer);
+    PyBuffer_Release(&b_buffer);
+    PyBuffer_Release(&out_buffer);
+    return return_obj;
+}
+
+// There are several flags we can use to define the functions:
+// - `METH_O`: Single object argument
+// - `METH_VARARGS`: Variable number of arguments
+// - `METH_FASTCALL`: Fast calling convention
+// - `METH_KEYWORDS`: Accepts keyword arguments, can be combined with `METH_FASTCALL`
+//
+// https://llllllllll.github.io/c-extension-tutorial/appendix.html#c.PyMethodDef.ml_flags
static PyMethodDef simsimd_methods[] = {
    // Introspecting library and hardware capabilities
-    {
-        "get_capabilities",
-        (PyCFunction)api_get_capabilities,
-        METH_NOARGS,
-        "Get the current hardware SIMD capabilities as a dictionary of feature flags.\n"
-        "On x86 includes: 'serial', 'haswell', 'skylake', 'ice', 'genoa', 'sapphire', 'turin'.\n"
-        "On Arm includes: 'serial', 'neon', 'sve', 'sve2', and their extensions.\n",
-    },
-    {
-        "enable_capability",
-        (PyCFunction)api_enable_capability,
-        METH_VARARGS,
-        "Enable a specific SIMD kernel family.\n\n"
-        "Args:\n"
-        "    capability (str): The name of the SIMD feature to enable (e.g., 'haswell').",
-    },
-    {
-        "disable_capability",
-        (PyCFunction)api_disable_capability,
-        METH_VARARGS,
-        "Disable a specific SIMD kernel family.\n\n"
-        "Args:\n"
-        "    capability (str): The name of the SIMD feature to disable (e.g., 'haswell').",
-    },
+    {"get_capabilities", (PyCFunction)api_get_capabilities, METH_NOARGS, doc_get_capabilities},
+    {"enable_capability", (PyCFunction)api_enable_capability, METH_O, doc_enable_capability},
+    {"disable_capability", (PyCFunction)api_disable_capability, METH_O, doc_disable_capability},

    // NumPy and SciPy compatible interfaces for dense vector representations
    // Each function can compute distances between:
    // - A pair of vectors
    // - A batch of vector pairs (two matrices of identical shape)
    // - A matrix of vectors and a single vector
-    {
-        "euclidean",
-        (PyCFunction)api_l2,
-        METH_FASTCALL,
-        "Compute Euclidean (L2) distances between two matrices.\n\n"
-        "Args:\n"
-        "    a (NDArray): First matrix or vector.\n"
-        "    b (NDArray): Second matrix or vector.\n"
-        "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
-        "    out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n\n"
-        "Returns:\n"
-        "    DistancesTensor: The squared Euclidean distances.\n\n"
-        "Equivalent to: `scipy.spatial.distance.euclidean`.\n"
-        "Notes:\n"
-        "    * `a` and `b` are positional-only arguments, while `dtype` and `out_dtype` are keyword-only arguments.",
-    },
-    {
-        "sqeuclidean",
-        (PyCFunction)api_l2sq,
-        METH_FASTCALL,
-        "Compute squared Euclidean (L2) distances between two matrices.\n\n"
-        "Args:\n"
-        "    a (NDArray): First matrix or vector.\n"
-        "    b (NDArray): Second matrix or vector.\n"
-        "    dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n"
-        "    out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n\n"
-        "Returns:\n"
-        "    DistancesTensor: The squared Euclidean distances.\n\n"
-        "Equivalent to: `scipy.spatial.distance.sqeuclidean`.\n"
-        "Notes:\n"
-        "    * `a` and `b` are positional-only arguments, while `dtype` and `out_dtype` are keyword-only arguments.",
-    },
-    {
-        "cosine",
-        (PyCFunction)api_cos,
-        METH_FASTCALL,
-        "Compute cosine (angular) distances between two matrices.\n\n"
-        "Args:\n"
-        "    a 
(NDArray): First matrix or vector.\n" - " b (NDArray): Second matrix or vector.\n" - " dtype (Union[IntegralType, FloatType], optional): Override the presumed input type.\n" - " out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n\n" - "Returns:\n" - " DistancesTensor: The cosine distances.\n\n" - "Equivalent to: `scipy.spatial.distance.cosine`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` and `out_dtype` are keyword-only arguments.", - }, - { - "inner", - (PyCFunction)api_dot, - METH_FASTCALL, - "Compute the inner (dot) product between two matrices (real or complex).\n\n" - "Args:\n" - " a (NDArray): First matrix or vector.\n" - " b (NDArray): Second matrix or vector.\n" - " dtype (Union[FloatType, ComplexType], optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The inner product.\n\n" - "Equivalent to: `numpy.inner`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "dot", - (PyCFunction)api_dot, - METH_FASTCALL, - "Compute the dot product between two matrices (real or complex).\n\n" - "Args:\n" - " a (NDArray): First matrix or vector.\n" - " b (NDArray): Second matrix or vector.\n" - " dtype (Union[FloatType, ComplexType], optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The dot product.\n\n" - "Equivalent to: `numpy.dot`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "vdot", - (PyCFunction)api_vdot, - METH_FASTCALL, - "Compute the conjugate dot product between two complex matrices.\n\n" - "Args:\n" - " a (NDArray): First complex matrix or vector.\n" - " b (NDArray): Second complex matrix or vector.\n" - " dtype (Union[ComplexType], optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The conjugate dot product.\n\n" - "Equivalent to: `numpy.vdot`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "hamming", - (PyCFunction)api_hamming, - METH_FASTCALL, - "Compute Hamming distances between two matrices.\n\n" - "Args:\n" - " a (NDArray): First binary matrix or vector.\n" - " b (NDArray): Second binary matrix or vector.\n" - " dtype (IntegralType, optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The Hamming distances.\n\n" - "Equivalent to: `scipy.spatial.distance.hamming`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "jaccard", - (PyCFunction)api_jaccard, - METH_FASTCALL, - "Compute Jaccard distances (bitwise Tanimoto) between two matrices.\n\n" - "Args:\n" - " a (NDArray): First binary matrix or vector.\n" - " b (NDArray): Second binary matrix or vector.\n" - " dtype (IntegralType, optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The Jaccard distances.\n\n" - "Equivalent to: `scipy.spatial.distance.jaccard`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "kullbackleibler", - (PyCFunction)api_kl, - METH_FASTCALL, - "Compute Kullback-Leibler divergences between two matrices.\n\n" - "Args:\n" - " a (NDArray): First floating-point matrix or vector.\n" - " b (NDArray): Second floating-point matrix or vector.\n" - " dtype (IntegralType, optional): Override the presumed input type.\n\n" - 
"Returns:\n" - " DistancesTensor: The Kullback-Leibler divergences distances.\n\n" - "Equivalent to: `scipy.special.kl_div`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, - { - "jensenshannon", - (PyCFunction)api_js, - METH_FASTCALL, - "Compute Jensen-Shannon divergences between two matrices.\n\n" - "Args:\n" - " a (NDArray): First floating-point matrix or vector.\n" - " b (NDArray): Second floating-point matrix or vector.\n" - " dtype (IntegralType, optional): Override the presumed input type.\n\n" - "Returns:\n" - " DistancesTensor: The Jensen-Shannon divergences distances.\n\n" - "Equivalent to: `scipy.spatial.distance.jensenshannon`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.", - }, + {"l2", (PyCFunction)api_l2, METH_FASTCALL | METH_KEYWORDS, doc_l2}, + {"l2sq", (PyCFunction)api_l2sq, METH_FASTCALL | METH_KEYWORDS, doc_l2sq}, + {"kl", (PyCFunction)api_kl, METH_FASTCALL | METH_KEYWORDS, doc_kl}, + {"js", (PyCFunction)api_js, METH_FASTCALL | METH_KEYWORDS, doc_js}, + {"cos", (PyCFunction)api_cos, METH_FASTCALL | METH_KEYWORDS, doc_cos}, + {"dot", (PyCFunction)api_dot, METH_FASTCALL | METH_KEYWORDS, doc_dot}, + {"vdot", (PyCFunction)api_vdot, METH_FASTCALL | METH_KEYWORDS, doc_vdot}, + {"hamming", (PyCFunction)api_hamming, METH_FASTCALL | METH_KEYWORDS, doc_hamming}, + {"jaccard", (PyCFunction)api_jaccard, METH_FASTCALL | METH_KEYWORDS, doc_jaccard}, + + // Aliases + {"euclidean", (PyCFunction)api_l2, METH_FASTCALL | METH_KEYWORDS, doc_l2}, + {"sqeuclidean", (PyCFunction)api_l2sq, METH_FASTCALL | METH_KEYWORDS, doc_l2sq}, + {"cosine", (PyCFunction)api_cos, METH_FASTCALL | METH_KEYWORDS, doc_cos}, + {"inner", (PyCFunction)api_dot, METH_FASTCALL | METH_KEYWORDS, doc_dot}, + {"kullbackleibler", (PyCFunction)api_kl, METH_FASTCALL | METH_KEYWORDS, doc_kl}, + {"jensenshannon", (PyCFunction)api_js, METH_FASTCALL | METH_KEYWORDS, doc_js}, // Conventional `cdist` interface for pairwise distances - { - "cdist", - (PyCFunction)api_cdist, - METH_FASTCALL | METH_KEYWORDS, - "Compute pairwise distances between two sets of input matrices.\n\n" - "Args:\n" - " a (NDArray): First matrix.\n" - " b (NDArray): Second matrix.\n" - " metric (str, optional): Distance metric to use (e.g., 'sqeuclidean', 'cosine').\n" - " threads (int, optional): Number of threads to use (default is 1).\n" - " dtype (Union[IntegralType, FloatType, ComplexType], optional): Override the presumed input type.\n" - " out_dtype (Union[FloatType, ComplexType], optional): Result type, default is 'float64'.\n\n" - "Returns:\n" - " DistancesTensor: Pairwise distances between all inputs.\n\n" - "Equivalent to: `scipy.spatial.distance.cdist`.\n" - "Notes:\n" - " * `a` and `b` are positional-only arguments.\n" - " * `metric` can be positional or keyword.\n" - " * `threads`, `dtype`, and `out_dtype` are keyword-only arguments.", - }, - - // Exposing underlying API for USearch - { - "pointer_to_euclidean", - (PyCFunction)api_l2_pointer, - METH_VARARGS, - "Retrieve the function pointer for the Euclidean distance function as an integer.", - }, - { - "pointer_to_sqeuclidean", - (PyCFunction)api_l2sq_pointer, - METH_VARARGS, - "Retrieve the function pointer for the squared Euclidean distance function as an integer.", - }, - { - "pointer_to_cosine", - (PyCFunction)api_cos_pointer, - METH_VARARGS, - "Retrieve the function pointer for the cosine distance function as an integer.", - }, - { - "pointer_to_inner", - 
(PyCFunction)api_dot_pointer, - METH_VARARGS, - "Retrieve the function pointer for the inner (dot) product function as an integer.", - }, - { - "pointer_to_dot", - (PyCFunction)api_dot_pointer, - METH_VARARGS, - "Retrieve the function pointer for the dot product function as an integer.", - }, - { - "pointer_to_vdot", - (PyCFunction)api_vdot_pointer, - METH_VARARGS, - "Retrieve the function pointer for the conjugate dot product function as an integer.", - }, - { - "pointer_to_kullbackleibler", - (PyCFunction)api_kl_pointer, - METH_VARARGS, - "Retrieve the function pointer for the Kullback-Leibler divergence function as an integer.", - }, - { - "pointer_to_jensenshannon", - (PyCFunction)api_js_pointer, - METH_VARARGS, - "Retrieve the function pointer for the Jensen-Shannon divergence function as an integer.", - }, + {"cdist", (PyCFunction)api_cdist, METH_FASTCALL | METH_KEYWORDS, doc_cdist}, + + // Exposing underlying API for USearch `CompiledMetric` + {"pointer_to_euclidean", (PyCFunction)api_l2_pointer, METH_O, doc_l2_pointer}, + {"pointer_to_sqeuclidean", (PyCFunction)api_l2sq_pointer, METH_O, doc_l2sq_pointer}, + {"pointer_to_cosine", (PyCFunction)api_cos_pointer, METH_O, doc_cos_pointer}, + {"pointer_to_inner", (PyCFunction)api_dot_pointer, METH_O, doc_dot_pointer}, + {"pointer_to_dot", (PyCFunction)api_dot_pointer, METH_O, doc_dot_pointer}, + {"pointer_to_vdot", (PyCFunction)api_vdot_pointer, METH_O, doc_vdot_pointer}, + {"pointer_to_kullbackleibler", (PyCFunction)api_kl_pointer, METH_O, doc_kl_pointer}, + {"pointer_to_jensenshannon", (PyCFunction)api_js_pointer, METH_O, doc_js_pointer}, // Set operations - { - "intersect", - (PyCFunction)api_intersect, - METH_FASTCALL, - "Compute the intersection of two sorted integer arrays.\n\n" - "Args:\n" - " a (NDArray): First sorted integer array.\n" - " b (NDArray): Second sorted integer array.\n\n" - "Returns:\n" - " float: The number of intersecting elements.\n\n" - "Similar to: `numpy.intersect1d`.", - }, + {"intersect", (PyCFunction)api_intersect, METH_FASTCALL, doc_intersect}, // Curved spaces - { - "bilinear", - (PyCFunction)api_bilinear, - METH_FASTCALL, - "Compute the bilinear form between two vectors given a metric tensor.\n\n" - "Args:\n" - " a (NDArray): First vector.\n" - " b (NDArray): Second vector.\n" - " metric_tensor (NDArray): The metric tensor defining the bilinear form.\n" - " dtype (FloatType, optional): Override the presumed input type.\n\n" - "Returns:\n" - " float: The bilinear form.\n\n" - "Equivalent to: `numpy.dot` with a metric tensor.\n" - "Notes:\n" - " * `a`, `b`, and `metric_tensor` are positional-only arguments, while `dtype` is keyword-only.", - }, + {"bilinear", (PyCFunction)api_bilinear, METH_FASTCALL | METH_KEYWORDS, doc_bilinear}, + {"mahalanobis", (PyCFunction)api_mahalanobis, METH_FASTCALL | METH_KEYWORDS, doc_mahalanobis}, - { - "mahalanobis", - (PyCFunction)api_mahalanobis, - METH_FASTCALL, - "Compute the Mahalanobis distance between two vectors given an inverse covariance matrix.\n\n" - "Args:\n" - " a (NDArray): First vector.\n" - " b (NDArray): Second vector.\n" - " inverse_covariance (NDArray): The inverse of the covariance matrix.\n" - " dtype (FloatType, optional): Override the presumed input type.\n\n" - "Returns:\n" - " float: The Mahalanobis distance.\n\n" - "Equivalent to: `scipy.spatial.distance.mahalanobis`.\n" - "Notes:\n" - " * `a`, `b`, and `inverse_covariance` are positional-only arguments, while `dtype` is keyword-only.", - }, + // Vectorized operations + {"fma", (PyCFunction)api_fma, 
METH_FASTCALL | METH_KEYWORDS, doc_fma},
+    {"wsum", (PyCFunction)api_wsum, METH_FASTCALL | METH_KEYWORDS, doc_wsum},

     // Sentinel
     {NULL, NULL, 0, NULL}};

+static char const doc_module[] = //
+    "Portable mixed-precision BLAS-like vector math library for x86 and ARM.\n"
+    "\n"
+    "Performance Recommendations:\n"
+    " - Avoid converting to NumPy arrays. SimSIMD works with any Tensor implementation\n"
+    "   compatible with Python's Buffer Protocol, which may come from PyTorch, TensorFlow, etc.\n"
+    " - In low-latency environments, provide the output array with the `out=` parameter\n"
+    "   to avoid expensive memory allocations on the hot path.\n"
+    " - On modern CPUs, if the application allows, prefer low-precision numeric types.\n"
+    "   Whenever possible, use 'bf16' and 'f16' over 'f32'. Consider quantizing to 'i8'\n"
+    "   and 'u8' for the highest hardware compatibility and performance.\n"
+    " - If you are only interested in relative proximity instead of the absolute distance,\n"
+    "   prefer simpler kernels, like the Squared Euclidean distance over the Euclidean distance.\n"
+    " - Use row-major contiguous matrix representations. Strides between rows won't have a significant\n"
+    "   impact on performance, but most modern HPC packages explicitly ban non-contiguous rows,\n"
+    "   where the nearby matrix cells within a row have multi-byte gaps.\n"
+    " - The CPython runtime has a noticeable overhead for function calls, so consider batching\n"
+    "   kernel invocations. Many kernels can compute not only the 1-to-1 distance between vectors,\n"
+    "   but also 1-to-N and N-to-N distances between two batches of vectors packed into matrices.\n"
+    "\n"
+    "Example:\n"
+    "    >>> import simsimd\n"
+    "    >>> simsimd.l2(a, b)\n"
+    "\n"
+    "Mixed-precision 1-to-N example with numeric types missing in NumPy, but present in PyTorch:\n"
+    "    >>> import simsimd\n"
+    "    >>> import torch\n"
+    "    >>> a = torch.randn(1536, dtype=torch.bfloat16)\n"
+    "    >>> b = torch.randn((100, 1536), dtype=torch.bfloat16)\n"
+    "    >>> c = torch.zeros(100, dtype=torch.float32)\n"
+    "    >>> simsimd.l2(a, b, dtype='bfloat16', out=c)\n";
+
 static PyModuleDef simsimd_module = {
-    PyModuleDef_HEAD_INIT,
-    .m_name = "SimSIMD",
-    .m_doc = "Fastest SIMD-Accelerated Vector Similarity Functions for x86 and Arm",
-    .m_size = -1,
-    .m_methods = simsimd_methods,
+    PyModuleDef_HEAD_INIT, .m_name = "SimSIMD", .m_doc = doc_module, .m_size = -1, .m_methods = simsimd_methods,
 };

 PyMODINIT_FUNC PyInit_simsimd(void) {
-    PyObject* m;
+    PyObject *m;

-    if (PyType_Ready(&DistancesTensorType) < 0)
-        return NULL;
+    if (PyType_Ready(&DistancesTensorType) < 0) return NULL;

     m = PyModule_Create(&simsimd_module);
-    if (m == NULL)
-        return NULL;
+    if (m == NULL) return NULL;

     // Add version metadata
     {
@@ -1489,7 +2056,7 @@ PyMODINIT_FUNC PyInit_simsimd(void) {
     }

     Py_INCREF(&DistancesTensorType);
-    if (PyModule_AddObject(m, "DistancesTensor", (PyObject*)&DistancesTensorType) < 0) {
+    if (PyModule_AddObject(m, "DistancesTensor", (PyObject *)&DistancesTensorType) < 0) {
         Py_XDECREF(&DistancesTensorType);
         Py_XDECREF(m);
         return NULL;
diff --git a/scripts/bench.cxx b/scripts/bench.cxx
index b438e669..e948c70f 100644
--- a/scripts/bench.cxx
+++ b/scripts/bench.cxx
@@ -3,6 +3,8 @@
 #include <cstring>       // `std::memcpy`
 #include <random>        // `std::uniform_int_distribution`
 #include <thread>        // `std::thread`
+#include <tuple>         // `std::tuple` for callable introspection
+#include // ``
 #include <unordered_set> // `std::unordered_set`
 #include <vector>        // `std::vector`
@@ -111,6 +113,7 @@ template <typename scalar_at> struct vector_gt {
     std::size_t size_bytes() const noexcept { return divide_round_up(dimensions_ * sizeof(scalar_t)); }
+    scalar_t* data() noexcept { return buffer_; }
     scalar_t const* data() const noexcept { return buffer_; }

     /**
@@ -480,6 +483,99 @@ void measure_sparse(bm::State& state, metric_at metric, metric_at baseline, std:
         std::accumulate(results_contender.begin(), results_contender.end(), 0.0) / results_contender.size();
 }

+template <typename... function_args_at>
+constexpr std::size_t function_args_count(void (*function)(function_args_at...)) {
+    return sizeof...(function_args_at);
+}
+
+/**
+ *  @brief Measures the performance of a vector-vector @b FMA function against a baseline using Google Benchmark.
+ *  @tparam pair_at The type representing the vector pair used in the measurement.
+ *  @tparam kernel_at The type of the kernel function (default is void).
+ *  @param state The benchmark state object provided by Google Benchmark.
+ *  @param kernel The kernel function to benchmark.
+ *  @param baseline The baseline function to compare against.
+ *  @param l2_metric The L2 distance function used to compare the baseline and contender outputs.
+ *  @param dimensions The number of dimensions in the vectors.
+ */
+template <typename pair_at, typename kernel_at = void, typename l2_metric_at = void>
+void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metric_at l2_metric,
+                 std::size_t dimensions) {
+
+    using pair_t = pair_at;
+    using vector_t = typename pair_at::vector_t;
+
+    constexpr simsimd_distance_t alpha = 0.2;
+    constexpr simsimd_distance_t beta = 0.3;
+    static_assert(function_args_count(kernel_at{}) >= 6 && function_args_count(kernel_at{}) <= 7,
+                  "Kernel must take two or three vectors.");
+
+    auto call_baseline = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) {
+        if constexpr (function_args_count(kernel_at{}) == 6) {
+            baseline(a.data(), c.data(), a.dimensions(), alpha, beta, d.data());
+        } else {
+            baseline(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data());
+        }
+    };
+    auto call_contender = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) {
+        if constexpr (function_args_count(kernel_at{}) == 6) {
+            kernel(a.data(), c.data(), a.dimensions(), alpha, beta, d.data());
+        } else {
+            kernel(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data());
+        }
+    };
+
+    // Let's average the distance results over many quads.
+    struct quad_t {
+        vector_t a, b, c, d;
+    };
+    constexpr std::size_t quads_count = 128;
+    std::vector<quad_t> quads(quads_count);
+    for (std::size_t i = 0; i != quads.size(); ++i) {
+        auto& quad = quads[i];
+        quad.a = quad.b = quad.c = quad.d = vector_t(dimensions);
+        quad.a.randomize(static_cast<std::uint32_t>(i));
+        quad.b.set(2); // Having a small constant here will help avoid overflows
+        quad.c.randomize(static_cast<std::uint32_t>(i) + 54321u);
+    }
+
+    // Initialize the output buffers for distance calculations.
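+    // The accuracy check below follows the usual relative-error recipe: run the
+    // baseline and the contender on the same quads, take the L2 distance between
+    // their outputs, and normalize it by the larger of the two output norms,
+    // i.e. `error = |d_baseline - d_contender| / max(|d_baseline|, |d_contender|)`.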
+    vector_t baseline_d(dimensions), contender_d(dimensions), zeros(dimensions);
+    std::vector<simsimd_distance_t> l2_metric_from_baseline(quads.size());
+    std::vector<simsimd_distance_t> l2_baseline_result_norm(quads.size());
+    std::vector<simsimd_distance_t> l2_contender_result_norm(quads.size());
+    zeros.set(0);
+    double mean_delta = 0, mean_relative_error = 0;
+    for (std::size_t i = 0; i != quads.size(); ++i) {
+        quad_t& quad = quads[i];
+        call_baseline(quad.a, quad.b, quad.c, baseline_d);
+        call_contender(quad.a, quad.b, quad.c, contender_d);
+        l2_metric(baseline_d.data(), contender_d.data(), dimensions, &l2_metric_from_baseline[i]);
+        l2_metric(baseline_d.data(), zeros.data(), dimensions, &l2_baseline_result_norm[i]);
+        l2_metric(contender_d.data(), zeros.data(), dimensions, &l2_contender_result_norm[i]);
+
+        mean_delta += std::abs(l2_metric_from_baseline[i]);
+        mean_relative_error +=
+            std::abs(l2_metric_from_baseline[i]) / (std::max)(l2_baseline_result_norm[i], l2_contender_result_norm[i]);
+    }
+    mean_delta /= quads_count;
+    mean_relative_error /= quads_count;
+
+    // The actual benchmarking loop.
+    std::size_t iterations = 0;
+    for (auto _ : state) {
+        quad_t& quad = quads[iterations & (quads_count - 1)];
+        call_contender(quad.a, quad.b, quad.c, quad.d);
+        iterations++;
+    }
+
+    // Measure the mean absolute delta and relative error.
+    state.counters["abs_delta"] = mean_delta;
+    state.counters["relative_error"] = mean_relative_error;
+    state.counters["bytes"] = bm::Counter(
+        iterations * quads[0].a.size_bytes() * (function_args_count(kernel_at{}) > 6 ? 3 : 2), bm::Counter::kIsRate);
+    state.counters["pairs"] = bm::Counter(iterations, bm::Counter::kIsRate);
+}
+
 template <typename scalar_at, typename metric_at = void>
 void dense_(std::string name, metric_at* distance_func, metric_at* baseline_func) {
     using pair_t = vectors_pair_gt<scalar_at>;
@@ -490,6 +586,16 @@ void dense_(std::string name, metric_at* distance_func, metric_at* baseline_func
         ->Threads(default_threads);
 }

+template <typename scalar_at, typename kernel_at = void, typename l2_metric_at = void>
+void fma_(std::string name, kernel_at* kernel_func, kernel_at* baseline_func, l2_metric_at* l2_metric_func) {
+    using pair_t = vectors_pair_gt<scalar_at>;
+    std::string bench_name = name + "<" + std::to_string(dense_dimensions) + "d>";
+    bm::RegisterBenchmark(bench_name.c_str(), measure_fma<pair_t, kernel_at*, l2_metric_at*>, kernel_func,
+                          baseline_func, l2_metric_func, dense_dimensions)
+        ->MinTime(default_seconds)
+        ->Threads(default_threads);
+}
+
 template <typename scalar_at, typename metric_at = void>
 void sparse_(std::string name, metric_at* distance_func, metric_at* baseline_func) {
@@ -637,81 +743,127 @@ int main(int argc, char** argv) {
 #endif

 #if SIMSIMD_TARGET_NEON
-    dense_("dot_f16_neon", simsimd_dot_f16_neon, simsimd_dot_f16_accurate);
-    dense_("cos_f16_neon", simsimd_cos_f16_neon, simsimd_cos_f16_accurate);
-    dense_("l2sq_f16_neon", simsimd_l2sq_f16_neon, simsimd_l2sq_f16_accurate);
-    dense_("kl_f16_neon", simsimd_kl_f16_neon, simsimd_kl_f16_accurate);
-    dense_("js_f16_neon", simsimd_js_f16_neon, simsimd_js_f16_accurate);
-
-    dense_("dot_bf16_neon", simsimd_dot_bf16_neon, simsimd_dot_bf16_accurate);
-    dense_("cos_bf16_neon", simsimd_cos_bf16_neon, simsimd_cos_bf16_accurate);
-    dense_("l2sq_bf16_neon", simsimd_l2sq_bf16_neon, simsimd_l2sq_bf16_accurate);
-
     dense_("dot_f32_neon", simsimd_dot_f32_neon, simsimd_dot_f32_accurate);
     dense_("cos_f32_neon", simsimd_cos_f32_neon, simsimd_cos_f32_accurate);
     dense_("l2sq_f32_neon", simsimd_l2sq_f32_neon, simsimd_l2sq_f32_accurate);
+    dense_("l2_f32_neon", simsimd_l2_f32_neon, simsimd_l2_f32_accurate);
     dense_("kl_f32_neon", simsimd_kl_f32_neon, simsimd_kl_f32_accurate);
     dense_("js_f32_neon", simsimd_js_f32_neon, simsimd_js_f32_accurate);

     dense_("cos_f64_neon", simsimd_cos_f64_neon, simsimd_cos_f64_serial);
     dense_("l2sq_f64_neon", simsimd_l2sq_f64_neon, simsimd_l2sq_f64_serial);
+    dense_("l2_f64_neon", simsimd_l2_f64_neon, simsimd_l2_f64_serial);
     dense_("cos_i8_neon", simsimd_cos_i8_neon, simsimd_cos_i8_serial);
     dense_("l2sq_i8_neon", simsimd_l2sq_i8_neon, simsimd_l2sq_i8_serial);
+    dense_("l2_i8_neon", simsimd_l2_i8_neon, simsimd_l2_i8_serial);
     dense_("dot_i8_neon", simsimd_dot_i8_neon, simsimd_dot_i8_serial);
     dense_("cos_u8_neon", simsimd_cos_u8_neon, simsimd_cos_u8_serial);
     dense_("l2sq_u8_neon", simsimd_l2sq_u8_neon, simsimd_l2sq_u8_serial);
+    dense_("l2_u8_neon", simsimd_l2_u8_neon, simsimd_l2_u8_serial);
     dense_("dot_u8_neon", simsimd_dot_u8_neon, simsimd_dot_u8_serial);
     dense_("hamming_b8_neon", simsimd_hamming_b8_neon, simsimd_hamming_b8_serial);
     dense_("jaccard_b8_neon", simsimd_jaccard_b8_neon, simsimd_jaccard_b8_serial);
-    dense_("dot_bf16c_neon", simsimd_dot_bf16c_neon, simsimd_dot_bf16c_accurate);
-    dense_("vdot_bf16c_neon", simsimd_vdot_bf16c_neon, simsimd_vdot_bf16c_accurate);
-    dense_("dot_f16c_neon", simsimd_dot_f16c_neon, simsimd_dot_f16c_accurate);
-    dense_("vdot_f16c_neon", simsimd_vdot_f16c_neon, simsimd_vdot_f16c_accurate);
     dense_("dot_f32c_neon", simsimd_dot_f32c_neon, simsimd_dot_f32c_accurate);
     dense_("vdot_f32c_neon", simsimd_vdot_f32c_neon, simsimd_vdot_f32c_accurate);
     curved_("bilinear_f32_neon", simsimd_bilinear_f32_neon, simsimd_bilinear_f32_accurate);
     curved_("mahalanobis_f32_neon", simsimd_mahalanobis_f32_neon, simsimd_mahalanobis_f32_accurate);
+
+    sparse_("intersect_u16_neon", simsimd_intersect_u16_neon, simsimd_intersect_u16_accurate);
+    sparse_("intersect_u32_neon", simsimd_intersect_u32_neon, simsimd_intersect_u32_accurate);
+
+    fma_("fma_f32_neon", simsimd_fma_f32_neon, simsimd_fma_f32_accurate, simsimd_l2_f32_accurate);
+    fma_("wsum_f32_neon", simsimd_wsum_f32_neon, simsimd_wsum_f32_accurate, simsimd_l2_f32_accurate);
+    fma_("fma_f32_serial", simsimd_fma_f32_serial, simsimd_fma_f32_accurate, simsimd_l2_f32_accurate);
+    fma_("wsum_f32_serial", simsimd_wsum_f32_serial, simsimd_wsum_f32_accurate, simsimd_l2_f32_accurate);
+
+#endif
+
+#if SIMSIMD_TARGET_NEON_F16
+    dense_("dot_f16c_neon", simsimd_dot_f16c_neon, simsimd_dot_f16c_accurate);
+    dense_("vdot_f16c_neon", simsimd_vdot_f16c_neon, simsimd_vdot_f16c_accurate);
+
+    dense_("dot_f16_neon", simsimd_dot_f16_neon, simsimd_dot_f16_accurate);
+    dense_("cos_f16_neon", simsimd_cos_f16_neon, simsimd_cos_f16_accurate);
+    dense_("l2sq_f16_neon", simsimd_l2sq_f16_neon, simsimd_l2sq_f16_accurate);
+    dense_("l2_f16_neon", simsimd_l2_f16_neon, simsimd_l2_f16_accurate);
+    dense_("kl_f16_neon", simsimd_kl_f16_neon, simsimd_kl_f16_accurate);
+    dense_("js_f16_neon", simsimd_js_f16_neon, simsimd_js_f16_accurate);
+
     curved_("bilinear_f16_neon", simsimd_bilinear_f16_neon, simsimd_bilinear_f16_accurate);
     curved_("mahalanobis_f16_neon", simsimd_mahalanobis_f16_neon, simsimd_mahalanobis_f16_accurate);
+
+    fma_("fma_f16_neon", simsimd_fma_f16_neon, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate);
+    fma_("wsum_f16_neon", simsimd_wsum_f16_neon, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate);
+
+    // FMA kernels for `u8` on NEON use `f16` arithmetic
+    fma_("fma_u8_neon", simsimd_fma_u8_neon, simsimd_fma_u8_accurate, simsimd_l2_u8_serial);
+    fma_("wsum_u8_neon", simsimd_wsum_u8_neon, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial);
+    fma_("fma_i8_neon", simsimd_fma_i8_neon, simsimd_fma_i8_accurate, simsimd_l2_i8_serial);
+    fma_("wsum_i8_neon", simsimd_wsum_i8_neon,
simsimd_wsum_i8_accurate, simsimd_l2_i8_serial);
+#endif
+
+#if SIMSIMD_TARGET_NEON_BF16
+    dense_("dot_bf16c_neon", simsimd_dot_bf16c_neon, simsimd_dot_bf16c_accurate);
+    dense_("vdot_bf16c_neon", simsimd_vdot_bf16c_neon, simsimd_vdot_bf16c_accurate);
+
+    dense_("dot_bf16_neon", simsimd_dot_bf16_neon, simsimd_dot_bf16_accurate);
+    dense_("cos_bf16_neon", simsimd_cos_bf16_neon, simsimd_cos_bf16_accurate);
+    dense_("l2sq_bf16_neon", simsimd_l2sq_bf16_neon, simsimd_l2sq_bf16_accurate);
+    dense_("l2_bf16_neon", simsimd_l2_bf16_neon, simsimd_l2_bf16_accurate);
+
     curved_("bilinear_bf16_neon", simsimd_bilinear_bf16_neon, simsimd_bilinear_bf16_accurate);
     curved_("mahalanobis_bf16_neon", simsimd_mahalanobis_bf16_neon, simsimd_mahalanobis_bf16_accurate);
-
-    sparse_("intersect_u16_neon", simsimd_intersect_u16_neon, simsimd_intersect_u16_accurate);
-    sparse_("intersect_u32_neon", simsimd_intersect_u32_neon, simsimd_intersect_u32_accurate);
 #endif

 #if SIMSIMD_TARGET_SVE
     dense_("dot_f16_sve", simsimd_dot_f16_sve, simsimd_dot_f16_accurate);
     dense_("cos_f16_sve", simsimd_cos_f16_sve, simsimd_cos_f16_accurate);
     dense_("l2sq_f16_sve", simsimd_l2sq_f16_sve, simsimd_l2sq_f16_accurate);
+    dense_("l2_f16_sve", simsimd_l2_f16_sve, simsimd_l2_f16_accurate);
     dense_("cos_bf16_sve", simsimd_cos_bf16_sve, simsimd_cos_bf16_accurate);
     dense_("l2sq_bf16_sve", simsimd_l2sq_bf16_sve, simsimd_l2sq_bf16_accurate);
+    dense_("l2_bf16_sve", simsimd_l2_bf16_sve, simsimd_l2_bf16_accurate);
     dense_("dot_f32_sve", simsimd_dot_f32_sve, simsimd_dot_f32_accurate);
     dense_("cos_f32_sve", simsimd_cos_f32_sve, simsimd_cos_f32_accurate);
     dense_("l2sq_f32_sve", simsimd_l2sq_f32_sve, simsimd_l2sq_f32_accurate);
+    dense_("l2_f32_sve", simsimd_l2_f32_sve, simsimd_l2_f32_accurate);
     dense_("dot_f64_sve", simsimd_dot_f64_sve, simsimd_dot_f64_serial);
     dense_("cos_f64_sve", simsimd_cos_f64_sve, simsimd_cos_f64_serial);
     dense_("l2sq_f64_sve", simsimd_l2sq_f64_sve, simsimd_l2sq_f64_serial);
+    dense_("l2_f64_sve", simsimd_l2_f64_sve, simsimd_l2_f64_serial);
     dense_("hamming_b8_sve", simsimd_hamming_b8_sve, simsimd_hamming_b8_serial);
     dense_("jaccard_b8_sve", simsimd_jaccard_b8_sve, simsimd_jaccard_b8_serial);
-    dense_("dot_f16c_sve", simsimd_dot_f16c_sve, simsimd_dot_f16c_accurate);
-    dense_("vdot_f16c_sve", simsimd_vdot_f16c_sve, simsimd_vdot_f16c_accurate);
     dense_("dot_f32c_sve", simsimd_dot_f32c_sve, simsimd_dot_f32c_accurate);
     dense_("vdot_f32c_sve", simsimd_vdot_f32c_sve, simsimd_vdot_f32c_accurate);
     dense_("dot_f64c_sve", simsimd_dot_f64c_sve, simsimd_dot_f64c_serial);
     dense_("vdot_f64c_sve", simsimd_vdot_f64c_sve, simsimd_vdot_f64c_serial);
 #endif

+#if SIMSIMD_TARGET_SVE_F16
+    dense_("dot_f16_sve", simsimd_dot_f16_sve, simsimd_dot_f16_accurate);
+    dense_("cos_f16_sve", simsimd_cos_f16_sve, simsimd_cos_f16_accurate);
+    dense_("l2sq_f16_sve", simsimd_l2sq_f16_sve, simsimd_l2sq_f16_accurate);
+    dense_("l2_f16_sve", simsimd_l2_f16_sve, simsimd_l2_f16_accurate);
+    dense_("dot_f16c_sve", simsimd_dot_f16c_sve, simsimd_dot_f16c_accurate);
+    dense_("vdot_f16c_sve", simsimd_vdot_f16c_sve, simsimd_vdot_f16c_accurate);
+#endif
+
+#if SIMSIMD_TARGET_SVE_BF16
+    dense_("cos_bf16_sve", simsimd_cos_bf16_sve, simsimd_cos_bf16_accurate);
+    dense_("l2sq_bf16_sve", simsimd_l2sq_bf16_sve, simsimd_l2sq_bf16_accurate);
+    dense_("l2_bf16_sve", simsimd_l2_bf16_sve, simsimd_l2_bf16_accurate);
+#endif
+
 #if SIMSIMD_TARGET_SVE2
     sparse_("intersect_u16_sve2", simsimd_intersect_u16_sve2, simsimd_intersect_u16_accurate);
     sparse_("intersect_u32_sve2",
simsimd_intersect_u32_sve2, simsimd_intersect_u32_accurate); @@ -721,19 +873,23 @@ int main(int argc, char** argv) { dense_("dot_f16_haswell", simsimd_dot_f16_haswell, simsimd_dot_f16_accurate); dense_("cos_f16_haswell", simsimd_cos_f16_haswell, simsimd_cos_f16_accurate); dense_("l2sq_f16_haswell", simsimd_l2sq_f16_haswell, simsimd_l2sq_f16_accurate); + dense_("l2_f16_haswell", simsimd_l2_f16_haswell, simsimd_l2_f16_accurate); dense_("kl_f16_haswell", simsimd_kl_f16_haswell, simsimd_kl_f16_accurate); dense_("js_f16_haswell", simsimd_js_f16_haswell, simsimd_js_f16_accurate); dense_("dot_bf16_haswell", simsimd_dot_bf16_haswell, simsimd_dot_bf16_accurate); dense_("cos_bf16_haswell", simsimd_cos_bf16_haswell, simsimd_cos_bf16_accurate); dense_("l2sq_bf16_haswell", simsimd_l2sq_bf16_haswell, simsimd_l2sq_bf16_accurate); + dense_("l2_bf16_haswell", simsimd_l2_bf16_haswell, simsimd_l2_bf16_accurate); dense_("cos_i8_haswell", simsimd_cos_i8_haswell, simsimd_cos_i8_serial); dense_("l2sq_i8_haswell", simsimd_l2sq_i8_haswell, simsimd_l2sq_i8_serial); + dense_("l2_i8_haswell", simsimd_l2_i8_haswell, simsimd_l2_i8_serial); dense_("dot_i8_haswell", simsimd_dot_i8_haswell, simsimd_dot_i8_serial); dense_("cos_u8_haswell", simsimd_cos_u8_haswell, simsimd_cos_u8_serial); dense_("l2sq_u8_haswell", simsimd_l2sq_u8_haswell, simsimd_l2sq_u8_serial); + dense_("l2_u8_haswell", simsimd_l2_u8_haswell, simsimd_l2_u8_serial); dense_("dot_u8_haswell", simsimd_dot_u8_haswell, simsimd_dot_u8_serial); dense_("hamming_b8_haswell", simsimd_hamming_b8_haswell, simsimd_hamming_b8_serial); @@ -749,12 +905,22 @@ int main(int argc, char** argv) { curved_("bilinear_bf16_haswell", simsimd_bilinear_bf16_haswell, simsimd_bilinear_bf16_accurate); curved_("mahalanobis_bf16_haswell", simsimd_mahalanobis_bf16_haswell, simsimd_mahalanobis_bf16_accurate); + fma_("fma_f64_haswell", simsimd_fma_f64_haswell, simsimd_fma_f64_serial, simsimd_l2_f64_serial); + fma_("wsum_f64_haswell", simsimd_wsum_f64_haswell, simsimd_wsum_f64_serial, simsimd_l2_f64_serial); + fma_("fma_f32_haswell", simsimd_fma_f32_haswell, simsimd_fma_f32_accurate, simsimd_l2_f32_accurate); + fma_("wsum_f32_haswell", simsimd_wsum_f32_haswell, simsimd_wsum_f32_accurate, simsimd_l2_f32_accurate); + fma_("fma_f16_haswell", simsimd_fma_f16_haswell, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate); + fma_("wsum_f16_haswell", simsimd_wsum_f16_haswell, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate); + fma_("fma_bf16_haswell", simsimd_fma_bf16_haswell, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate); + fma_("wsum_bf16_haswell", simsimd_wsum_bf16_haswell, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate); + #endif #if SIMSIMD_TARGET_GENOA dense_("dot_bf16_genoa", simsimd_dot_bf16_genoa, simsimd_dot_bf16_accurate); dense_("cos_bf16_genoa", simsimd_cos_bf16_genoa, simsimd_cos_bf16_accurate); dense_("l2sq_bf16_genoa", simsimd_l2sq_bf16_genoa, simsimd_l2sq_bf16_accurate); + dense_("l2_bf16_genoa", simsimd_l2_bf16_genoa, simsimd_l2_bf16_accurate); dense_("dot_bf16c_genoa", simsimd_dot_bf16c_genoa, simsimd_dot_bf16c_accurate); dense_("vdot_bf16c_genoa", simsimd_vdot_bf16c_genoa, simsimd_vdot_bf16c_accurate); @@ -767,25 +933,34 @@ int main(int argc, char** argv) { dense_("dot_f16_sapphire", simsimd_dot_f16_sapphire, simsimd_dot_f16_accurate); dense_("cos_f16_sapphire", simsimd_cos_f16_sapphire, simsimd_cos_f16_accurate); dense_("l2sq_f16_sapphire", simsimd_l2sq_f16_sapphire, simsimd_l2sq_f16_accurate); + dense_("l2_f16_sapphire", simsimd_l2_f16_sapphire, 
simsimd_l2_f16_accurate); dense_("kl_f16_sapphire", simsimd_kl_f16_sapphire, simsimd_kl_f16_accurate); dense_("js_f16_sapphire", simsimd_js_f16_sapphire, simsimd_js_f16_accurate); dense_("dot_f16c_sapphire", simsimd_dot_f16c_sapphire, simsimd_dot_f16c_accurate); dense_("vdot_f16c_sapphire", simsimd_vdot_f16c_sapphire, simsimd_vdot_f16c_accurate); + + fma_("fma_u8_sapphire", simsimd_fma_u8_sapphire, simsimd_fma_u8_accurate, simsimd_l2_u8_serial); + fma_("wsum_u8_sapphire", simsimd_wsum_u8_sapphire, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial); + fma_("fma_i8_sapphire", simsimd_fma_i8_sapphire, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); + fma_("wsum_i8_sapphire", simsimd_wsum_i8_sapphire, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); #endif #if SIMSIMD_TARGET_ICE dense_("cos_i8_ice", simsimd_cos_i8_ice, simsimd_cos_i8_serial); dense_("l2sq_i8_ice", simsimd_l2sq_i8_ice, simsimd_l2sq_i8_serial); + dense_("l2_i8_ice", simsimd_l2_i8_ice, simsimd_l2_i8_serial); dense_("dot_i8_ice", simsimd_dot_i8_ice, simsimd_dot_i8_serial); dense_("cos_u8_ice", simsimd_cos_u8_ice, simsimd_cos_u8_serial); dense_("l2sq_u8_ice", simsimd_l2sq_u8_ice, simsimd_l2sq_u8_serial); + dense_("l2_u8_ice", simsimd_l2_u8_ice, simsimd_l2_u8_serial); dense_("dot_u8_ice", simsimd_dot_u8_ice, simsimd_dot_u8_serial); dense_("dot_f64_skylake", simsimd_dot_f64_skylake, simsimd_dot_f64_serial); dense_("cos_f64_skylake", simsimd_cos_f64_skylake, simsimd_cos_f64_serial); dense_("l2sq_f64_skylake", simsimd_l2sq_f64_skylake, simsimd_l2sq_f64_serial); + dense_("l2_f64_skylake", simsimd_l2_f64_skylake, simsimd_l2_f64_serial); dense_("hamming_b8_ice", simsimd_hamming_b8_ice, simsimd_hamming_b8_serial); dense_("jaccard_b8_ice", simsimd_jaccard_b8_ice, simsimd_jaccard_b8_serial); @@ -798,6 +973,7 @@ int main(int argc, char** argv) { dense_("dot_f32_skylake", simsimd_dot_f32_skylake, simsimd_dot_f32_accurate); dense_("cos_f32_skylake", simsimd_cos_f32_skylake, simsimd_cos_f32_accurate); dense_("l2sq_f32_skylake", simsimd_l2sq_f32_skylake, simsimd_l2sq_f32_accurate); + dense_("l2_f32_skylake", simsimd_l2_f32_skylake, simsimd_l2_f32_accurate); dense_("kl_f32_skylake", simsimd_kl_f32_skylake, simsimd_kl_f32_accurate); dense_("js_f32_skylake", simsimd_js_f32_skylake, simsimd_js_f32_accurate); @@ -805,6 +981,14 @@ int main(int argc, char** argv) { dense_("vdot_f32c_skylake", simsimd_vdot_f32c_skylake, simsimd_vdot_f32c_accurate); dense_("dot_f64c_skylake", simsimd_dot_f64c_skylake, simsimd_dot_f64c_serial); dense_("vdot_f64c_skylake", simsimd_vdot_f64c_skylake, simsimd_vdot_f64c_serial); + + fma_("fma_f64_skylake", simsimd_fma_f64_skylake, simsimd_fma_f64_serial, simsimd_l2_f64_serial); + fma_("wsum_f64_skylake", simsimd_wsum_f64_skylake, simsimd_wsum_f64_serial, simsimd_l2_f64_serial); + fma_("fma_f32_skylake", simsimd_fma_f32_skylake, simsimd_fma_f32_accurate, simsimd_l2_f32_accurate); + fma_("wsum_f32_skylake", simsimd_wsum_f32_skylake, simsimd_wsum_f32_accurate, simsimd_l2_f32_accurate); + fma_("fma_bf16_skylake", simsimd_fma_bf16_skylake, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate); + fma_("wsum_bf16_skylake", simsimd_wsum_bf16_skylake, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate); + #endif sparse_("intersect_u16_serial", simsimd_intersect_u16_serial, simsimd_intersect_u16_accurate); @@ -824,31 +1008,37 @@ int main(int argc, char** argv) { dense_("dot_bf16_serial", simsimd_dot_bf16_serial, simsimd_dot_bf16_accurate); dense_("cos_bf16_serial", simsimd_cos_bf16_serial, simsimd_cos_bf16_accurate); 
dense_("l2sq_bf16_serial", simsimd_l2sq_bf16_serial, simsimd_l2sq_bf16_accurate); + dense_("l2_bf16_serial", simsimd_l2_bf16_serial, simsimd_l2_bf16_accurate); dense_("kl_bf16_serial", simsimd_kl_bf16_serial, simsimd_kl_bf16_accurate); dense_("js_bf16_serial", simsimd_js_bf16_serial, simsimd_js_bf16_accurate); dense_("dot_f16_serial", simsimd_dot_f16_serial, simsimd_dot_f16_accurate); dense_("cos_f16_serial", simsimd_cos_f16_serial, simsimd_cos_f16_accurate); dense_("l2sq_f16_serial", simsimd_l2sq_f16_serial, simsimd_l2sq_f16_accurate); + dense_("l2_f16_serial", simsimd_l2_f16_serial, simsimd_l2_f16_accurate); dense_("kl_f16_serial", simsimd_kl_f16_serial, simsimd_kl_f16_accurate); dense_("js_f16_serial", simsimd_js_f16_serial, simsimd_js_f16_accurate); dense_("dot_f32_serial", simsimd_dot_f32_serial, simsimd_dot_f32_accurate); dense_("cos_f32_serial", simsimd_cos_f32_serial, simsimd_cos_f32_accurate); dense_("l2sq_f32_serial", simsimd_l2sq_f32_serial, simsimd_l2sq_f32_accurate); + dense_("l2_f32_serial", simsimd_l2_f32_serial, simsimd_l2_f32_accurate); dense_("kl_f32_serial", simsimd_kl_f32_serial, simsimd_kl_f32_accurate); dense_("js_f32_serial", simsimd_js_f32_serial, simsimd_js_f32_accurate); dense_("dot_f64_serial", simsimd_dot_f64_serial, simsimd_dot_f64_serial); dense_("cos_f64_serial", simsimd_cos_f64_serial, simsimd_cos_f64_serial); dense_("l2sq_f64_serial", simsimd_l2sq_f64_serial, simsimd_l2sq_f64_serial); + dense_("l2_f64_serial", simsimd_l2_f64_serial, simsimd_l2_f64_serial); dense_("cos_i8_serial", simsimd_cos_i8_serial, simsimd_cos_i8_serial); dense_("l2sq_i8_serial", simsimd_l2sq_i8_serial, simsimd_l2sq_i8_serial); + dense_("l2_i8_serial", simsimd_l2_i8_serial, simsimd_l2_i8_serial); dense_("dot_i8_serial", simsimd_dot_i8_serial, simsimd_dot_i8_serial); dense_("cos_u8_serial", simsimd_cos_u8_serial, simsimd_cos_u8_serial); dense_("l2sq_u8_serial", simsimd_l2sq_u8_serial, simsimd_l2sq_u8_serial); + dense_("l2_u8_serial", simsimd_l2_u8_serial, simsimd_l2_u8_serial); dense_("dot_u8_serial", simsimd_dot_u8_serial, simsimd_dot_u8_serial); dense_("dot_f64c_serial", simsimd_dot_f64c_serial, simsimd_dot_f64c_serial); @@ -866,6 +1056,13 @@ int main(int argc, char** argv) { dense_("hamming_b8_serial", simsimd_hamming_b8_serial, simsimd_hamming_b8_serial); dense_("jaccard_b8_serial", simsimd_jaccard_b8_serial, simsimd_jaccard_b8_serial); + fma_("fma_f16_serial", simsimd_fma_f16_serial, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate); + fma_("wsum_f16_serial", simsimd_wsum_f16_serial, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate); + fma_("fma_u8_serial", simsimd_fma_u8_serial, simsimd_fma_u8_accurate, simsimd_l2_u8_serial); + fma_("wsum_u8_serial", simsimd_wsum_u8_serial, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial); + fma_("fma_i8_serial", simsimd_fma_i8_serial, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); + fma_("wsum_i8_serial", simsimd_wsum_i8_serial, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); + bm::RunSpecifiedBenchmarks(); bm::Shutdown(); return 0; diff --git a/scripts/test.py b/scripts/test.py index 2cff738c..85b23e67 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -59,10 +59,60 @@ import numpy as np numpy_available = True + + baseline_inner = np.inner + baseline_intersect = lambda x, y: len(np.intersect1d(x, y)) + baseline_bilinear = lambda x, y, z: x @ z @ y + + def baseline_fma(x, y, z, alpha, beta): + xy_scaled = np.multiply((alpha * x), y) + z_scaled = beta * z + r = xy_scaled + z_scaled + if np.issubdtype(x.dtype, np.integer): + r = np.round(r) + 
#! We need non-overflowing saturating addition for small integers, which NumPy lacks:
+            #! https://stackoverflow.com/questions/29611185/avoid-overflow-when-adding-numpy-arrays
+            if x.dtype == np.uint8:
+                r = np.clip(r, 0, 255, out=r)
+            elif x.dtype == np.int8:
+                r = np.clip(r, -128, 127, out=r)
+        return r.astype(x.dtype)
+
+    def baseline_wsum(x, y, alpha, beta):
+        x_scaled = alpha * x
+        y_scaled = beta * y
+        r = x_scaled + y_scaled
+        if np.issubdtype(x.dtype, np.integer):
+            r = np.round(r)
+            #! We need non-overflowing saturating addition for small integers, which NumPy lacks:
+            #! https://stackoverflow.com/questions/29611185/avoid-overflow-when-adding-numpy-arrays
+            if x.dtype == np.uint8:
+                r = np.clip(r, 0, 255, out=r)
+            elif x.dtype == np.int8:
+                r = np.clip(r, -128, 127, out=r)
+        return r.astype(x.dtype)
+
 except:
     # NumPy is not installed, most tests will be skipped
     numpy_available = False

+    baseline_inner = lambda x, y: sum(x[i] * y[i] for i in range(len(x)))
+    baseline_intersect = lambda x, y: len(set(x).intersection(y))
+
+    def baseline_bilinear(x, y, z):
+        result = 0
+        for i in range(len(x)):
+            for j in range(len(y)):
+                result += x[i] * z[i][j] * y[j]
+        return result
+
+    def baseline_fma(x, y, z, alpha, beta):
+        return [(alpha * xi) * yi + beta * zi for xi, yi, zi in zip(x, y, z)]
+
+    def baseline_wsum(x, y, alpha, beta):
+        return [(alpha * xi) + beta * yi for xi, yi in zip(x, y)]
+
+
 # At the time of Python 3.12, SciPy doesn't support 32-bit Windows on any CPU,
 # or 64-bit Windows on Arm. It also doesn't support `musllinux` distributions,
 # like CentOS, RedHat OS, and many others.
@@ -71,15 +121,12 @@
     scipy_available = True

-    baseline_inner = np.inner
     baseline_euclidean = lambda x, y: np.array(spd.euclidean(x, y))  #! SciPy returns a scalar
     baseline_sqeuclidean = spd.sqeuclidean
     baseline_cosine = spd.cosine
     baseline_jensenshannon = lambda x, y: spd.jensenshannon(x, y) ** 2
     baseline_hamming = lambda x, y: spd.hamming(x, y) * len(x)
     baseline_jaccard = spd.jaccard
-    baseline_intersect = lambda x, y: len(np.intersect1d(x, y))
-    baseline_bilinear = lambda x, y, z: x @ z @ y

     def baseline_mahalanobis(x, y, z):
         # If there was an error, or the value is NaN, we skip the test.
@@ -95,13 +142,11 @@ def baseline_mahalanobis(x, y, z):
     # SciPy is not installed, some tests will be skipped
     scipy_available = False

-    baseline_inner = lambda x, y: np.inner(x, y)
     baseline_cosine = lambda x, y: 1.0 - np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
     baseline_euclidean = lambda x, y: np.array([np.sqrt(np.sum((x - y) ** 2))])
     baseline_sqeuclidean = lambda x, y: np.sum((x - y) ** 2)
     baseline_jensenshannon = lambda p, q: (np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / 2
     baseline_hamming = lambda x, y: np.logical_xor(x, y).sum()
-    baseline_bilinear = lambda x, y, z: x @ z @ y

     def baseline_mahalanobis(x, y, z):
         diff = x - y
@@ -312,10 +357,11 @@ def collect_errors(
     - TODO: How much faster is SimSIMD than the baseline kernel?
     - TODO: How much faster is SimSIMD than the accurate kernel?
""" - absolute_baseline_error = np.abs(baseline_result - accurate_result) - relative_baseline_error = absolute_baseline_error / np.abs(accurate_result) - absolute_simsimd_error = np.abs(simsimd_result - accurate_result) - relative_simsimd_error = absolute_simsimd_error / np.abs(accurate_result) + eps = np.finfo(accurate_result.dtype).resolution + absolute_baseline_error = np.max(np.abs(baseline_result - accurate_result)) + relative_baseline_error = np.max(np.abs(baseline_result - accurate_result) / (np.abs(accurate_result) + eps)) + absolute_simsimd_error = np.max(np.abs(simsimd_result - accurate_result)) + relative_simsimd_error = np.max(np.abs(simsimd_result - accurate_result) / (np.abs(accurate_result) + eps)) stats["metric"].append(metric) stats["ndim"].append(ndim) @@ -403,6 +449,10 @@ def name_to_kernels(name: str): return baseline_hamming, simd.hamming elif name == "intersect": return baseline_intersect, simd.intersect + elif name == "fma": + return baseline_fma, simd.fma + elif name == "wsum": + return baseline_wsum, simd.wsum else: raise ValueError(f"Unknown kernel name: {name}") @@ -489,9 +539,12 @@ def test_capabilities_list(): simd.disable_capability("neon") -def to_array(x): +def to_array(x, dtype=None): if numpy_available: - return np.array(x) + y = np.array(x) + if dtype is not None: + y = y.astype(dtype) + return y @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @@ -516,12 +569,15 @@ def to_array(x): (simd.bilinear, TypeError, (to_array([1.0]),), {}), # Missing second vector and metric tensor # Test passing too many arguments to a method (simd.cosine, TypeError, (to_array([1.0]), to_array([1.0]), to_array([1.0])), {}), # Too many arguments - # Too many arguments (simd.cdist, TypeError, (to_array([[1.0]]), to_array([[1.0]]), "cos", "dos"), {}), # Too many arguments # Same argument as both positional and keyword (simd.cdist, TypeError, (to_array([[1.0]]), to_array([[1.0]]), "cos"), {"metric": "cos"}), # Applying real metric to complex numbers - missing kernel (simd.cosine, LookupError, (to_array([1 + 2j]), to_array([1 + 2j])), {}), + # Test incompatible vectors for cosine + (simd.cosine, ValueError, (to_array([1.0]), to_array([1.0, 2.0])), {}), # Different number of dimensions + (simd.cosine, TypeError, (to_array([1.0]), to_array([1], "uint32")), {}), # Floats and integers + (simd.cosine, TypeError, (to_array([1]), to_array([1], "float16")), {}), # Different floats ], ) def test_invalid_argument_handling(function, expected_error, args, kwargs): @@ -825,7 +881,7 @@ def test_cosine_zero_vector(ndim, dtype, capability): assert np.all(result >= 0), f"Negative result for cosine distance" -@pytest.mark.skip() # TODO: https://github.com/ashvardanian/SimSIMD/issues/206 +@pytest.mark.skip(reason="Lacks overflow protection: https://github.com/ashvardanian/SimSIMD/issues/206") # TODO @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(50) @pytest.mark.parametrize("ndim", [11, 97, 1536]) @@ -859,7 +915,7 @@ def test_overflow(ndim, dtype, metric, capability): collect_warnings(f"Arbitrary error raised in SciPy: {e}", stats_fixture) -@pytest.mark.skip() # TODO: https://github.com/ashvardanian/SimSIMD/issues/206 +@pytest.mark.skip(reason="Lacks overflow protection: https://github.com/ashvardanian/SimSIMD/issues/206") # TODO @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(50) @pytest.mark.parametrize("ndim", [131072, 262144]) @@ -975,6 +1031,123 @@ def test_intersect(dtype, 
first_length_bound, second_length_bound, capability):
     assert round(float(expected)) == round(float(result)), f"Missing {np.intersect1d(a, b)} from {a} and {b}"

+@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
+@pytest.mark.repeat(50)
+@pytest.mark.parametrize("ndim", [11, 97, 1536])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
+@pytest.mark.parametrize("kernel", ["fma"])
+@pytest.mark.parametrize("capability", possible_capabilities)
+def test_fma(ndim, dtype, kernel, capability, stats_fixture):
+    """Compares the simd.fma() kernel to a NumPy-based baseline, measuring the accuracy error for floating-point and small integer types."""
+
+    if dtype == "float16" and is_running_under_qemu():
+        pytest.skip("Testing low-precision math isn't reliable in QEMU")
+
+    np.random.seed()
+    if np.issubdtype(np.dtype(dtype), np.integer):
+        dtype_info = np.iinfo(np.dtype(dtype))
+        a = np.random.randint(dtype_info.min, dtype_info.max, size=ndim, dtype=dtype)
+        b = np.random.randint(dtype_info.min, dtype_info.max, size=ndim, dtype=dtype)
+        c = np.random.randint(dtype_info.min, dtype_info.max, size=ndim, dtype=dtype)
+        alpha = abs(np.random.randn(1).astype(np.float64).item()) / 512
+        beta = abs(np.random.randn(1).astype(np.float64).item()) / 3
+        atol = 1  # ? Allow at most one rounding error per vector
+        rtol = 0
+    else:
+        a = np.random.randn(ndim).astype(dtype)
+        b = np.random.randn(ndim).astype(dtype)
+        c = np.random.randn(ndim).astype(dtype)
+        alpha = np.random.randn(1).astype(np.float64).item()
+        beta = np.random.randn(1).astype(np.float64).item()
+        atol = SIMSIMD_ATOL
+        rtol = SIMSIMD_RTOL
+
+    keep_one_capability(capability)
+    baseline_kernel, simd_kernel = name_to_kernels(kernel)
+
+    accurate_dt, accurate = profile(
+        baseline_kernel,
+        a.astype(np.float64),
+        b.astype(np.float64),
+        c.astype(np.float64),
+        alpha=alpha,
+        beta=beta,
+    )
+    expected_dt, expected = profile(baseline_kernel, a, b, c, alpha=alpha, beta=beta)
+    result_dt, result = profile(simd_kernel, a, b, c, alpha=alpha, beta=beta)
+
+    np.testing.assert_allclose(result, expected.astype(np.float64), atol=atol, rtol=rtol)
+    collect_errors(
+        kernel,
+        ndim,
+        dtype,
+        accurate,
+        accurate_dt,
+        expected,
+        expected_dt,
+        result,
+        result_dt,
+        stats_fixture,
+    )
+
+
+@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
+@pytest.mark.repeat(50)
+@pytest.mark.parametrize("ndim", [11, 97, 1536])
+@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
+@pytest.mark.parametrize("kernel", ["wsum"])
+@pytest.mark.parametrize("capability", possible_capabilities)
+def test_wsum(ndim, dtype, kernel, capability, stats_fixture):
+    """Compares the simd.wsum() kernel to a NumPy-based baseline, measuring the accuracy error for floating-point and small integer types."""
+
+    if dtype == "float16" and is_running_under_qemu():
+        pytest.skip("Testing low-precision math isn't reliable in QEMU")
+
+    np.random.seed()
+    if np.issubdtype(np.dtype(dtype), np.integer):
+        dtype_info = np.iinfo(np.dtype(dtype))
+        a = np.random.randint(dtype_info.min, dtype_info.max, size=ndim, dtype=dtype)
+        b = np.random.randint(dtype_info.min, dtype_info.max, size=ndim, dtype=dtype)
+        alpha = abs(np.random.randn(1).astype(np.float64).item()) / 2
+        beta = abs(np.random.randn(1).astype(np.float64).item()) / 2
+        atol = 1  # ? Allow at most one rounding error per vector
+        rtol = 0
+    else:
+        a = np.random.randn(ndim).astype(dtype)
+        b = np.random.randn(ndim).astype(dtype)
+        alpha = np.random.randn(1).astype(np.float64).item()
+        beta = np.random.randn(1).astype(np.float64).item()
+        atol = SIMSIMD_ATOL
+        rtol = SIMSIMD_RTOL
+
+    keep_one_capability(capability)
+    baseline_kernel, simd_kernel = name_to_kernels(kernel)
+
+    accurate_dt, accurate = profile(
+        baseline_kernel,
+        a.astype(np.float64),
+        b.astype(np.float64),
+        alpha=alpha,
+        beta=beta,
+    )
+    expected_dt, expected = profile(baseline_kernel, a, b, alpha=alpha, beta=beta)
+    result_dt, result = profile(simd_kernel, a, b, alpha=alpha, beta=beta)
+
+    np.testing.assert_allclose(result, expected.astype(np.float64), atol=atol, rtol=rtol)
+    collect_errors(
+        kernel,
+        ndim,
+        dtype,
+        accurate,
+        accurate_dt,
+        expected,
+        expected_dt,
+        result,
+        result_dt,
+        stats_fixture,
+    )
+
+
 @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
 @pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
 @pytest.mark.parametrize("ndim", [11, 97, 1536])
@@ -1024,11 +1197,11 @@ def test_batch(ndim, dtype, capability):
     assert np.allclose(result_simd, result_np, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)

     # Distance between matrixes A (N x D scalars) and B (N x D scalars) in slices of bigger matrices.
-    A_exteded = np.random.randn(10, ndim + 11).astype(dtype)
-    B_extended = np.random.randn(10, ndim + 11).astype(dtype)
-    A = A_exteded[:, 1 : 1 + ndim]
+    A_extended = np.random.randn(10, ndim + 11).astype(dtype)
+    B_extended = np.random.randn(10, ndim + 13).astype(dtype)
+    A = A_extended[:, 1 : 1 + ndim]
     B = B_extended[:, 3 : 3 + ndim]
-    assert A.base is A_exteded and B.base is B_extended
+    assert A.base is A_extended and B.base is B_extended
     assert A.__array_interface__["strides"] is not None and B.__array_interface__["strides"] is not None
     result_np = [spd.sqeuclidean(A[i], B[i]) for i in range(10)]
     result_simd = np.array(simd.sqeuclidean(A, B)).astype(np.float64)
@@ -1042,6 +1215,23 @@ def test_batch(ndim, dtype, capability):
     result_simd = np.array(simd.sqeuclidean(A, B)).astype(np.float64)
     assert np.allclose(result_simd, result_np, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)

+    # Distance between matrices A (N x D scalars) and B (N x D scalars) with a different output type.
+    A = np.random.randn(10, ndim).astype(dtype)
+    B = np.random.randn(10, ndim).astype(dtype)
+    result_np = np.array([spd.sqeuclidean(A[i], B[i]) for i in range(10)]).astype(np.float32)
+    result_simd = np.array(simd.sqeuclidean(A, B, out_dtype="float32"))
+    assert np.allclose(result_simd, result_np, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
+    assert result_simd.dtype == result_np.dtype
+
+    # Distance between matrices A (N x D scalars) and B (N x D scalars) with a supplied output buffer.
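+    #! Unlike SciPy, which returns its `out=` buffer from the call, SimSIMD
+    #! returns `None` in the in-place mode and only fills the supplied buffer,
+    #! as the asserts below verify.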
+    A = np.random.randn(10, ndim).astype(dtype)
+    B = np.random.randn(10, ndim).astype(dtype)
+    result_np = np.array([spd.sqeuclidean(A[i], B[i]) for i in range(10)]).astype(np.float32)
+    result_simd = np.zeros(10, dtype=np.float32)
+    assert simd.sqeuclidean(A, B, out=result_simd) is None
+    assert np.allclose(result_simd, result_np, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
+    assert result_simd.dtype == result_np.dtype
+

 @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
 @pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
@@ -1063,16 +1253,58 @@ def test_cdist(ndim, input_dtype, out_dtype, metric, capability):
     # To test their ability to handle strided inputs, we are going to add one extra dimension.
     M, N = 10, 15
     A_extended = np.random.randn(M, ndim + 1).astype(input_dtype)
-    B_extended = np.random.randn(N, ndim + 1).astype(input_dtype)
+    B_extended = np.random.randn(N, ndim + 3).astype(input_dtype)
     A = A_extended[:, :ndim]
     B = B_extended[:, :ndim]

     if out_dtype is None:
         expected = spd.cdist(A, B, metric)
-        result = simd.cdist(A, B, metric=metric)
+        result = simd.cdist(A, B, metric)
+
+        #! Same functions can be used in-place, but SciPy doesn't support misaligned outputs
+        expected_out = np.zeros((M, N))
+        result_out_extended = np.zeros((M, N + 7))
+        result_out = result_out_extended[:, :N]
+        assert spd.cdist(A, B, metric, out=expected_out) is not None
+        assert simd.cdist(A, B, metric, out=result_out) is None
     else:
         expected = spd.cdist(A, B, metric).astype(out_dtype)
-        result = simd.cdist(A, B, metric=metric, out_dtype=out_dtype)
+        result = simd.cdist(A, B, metric, out_dtype=out_dtype)
+
+        #! Same functions can be used in-place, but SciPy doesn't support misaligned outputs
+        expected_out = np.zeros((M, N), dtype=np.float64)
+        result_out_extended = np.zeros((M, N + 7), dtype=out_dtype)
+        result_out = result_out_extended[:, :N]
+        assert spd.cdist(A, B, metric, out=expected_out) is not None
+        assert simd.cdist(A, B, metric, out=result_out) is None
+        #! Moreover, SciPy supports only double-precision outputs, so we need to downcast afterwards.
+        expected_out = expected_out.astype(out_dtype)
+
+    # Assert they're close.
+    np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
+    np.testing.assert_allclose(result_out, expected_out, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
+
+
+@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
+@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
+@pytest.mark.parametrize("ndim", [11, 97, 1536])
+@pytest.mark.parametrize("input_dtype", ["float32", "float16"])
+@pytest.mark.parametrize("out_dtype", [None, "float32", "int32"])
+@pytest.mark.parametrize("metric", ["cosine", "sqeuclidean"])
+def test_cdist_itself(ndim, input_dtype, out_dtype, metric):
+    """Compares the simd.cdist(A, A) function with scipy.spatial.distance.cdist(A, A), measuring the accuracy error for f16 and f32 types using sqeuclidean and cosine metrics."""
+
+    if input_dtype == "float16" and is_running_under_qemu():
+        pytest.skip("Testing low-precision math isn't reliable in QEMU")
+
+    np.random.seed()
+
+    A = np.random.randn(10, ndim + 1).astype(input_dtype)
+    if out_dtype is None:
+        expected = spd.cdist(A, A, metric)
+        result = simd.cdist(A, A, metric=metric)
+    else:
+        expected = spd.cdist(A, A, metric).astype(out_dtype)
+        result = simd.cdist(A, A, metric=metric, out_dtype=out_dtype)
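+    #! Note that `A` is deliberately passed as both arguments, so the kernels
+    #! must also behave well when the two inputs alias the same memory.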
     # Assert they're close.
     np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
@@ -1085,28 +1317,46 @@ def test_cdist(ndim, input_dtype, out_dtype, metric, capability):
 @pytest.mark.parametrize("metric", ["dot", "vdot"])
 @pytest.mark.parametrize("capability", possible_capabilities)
 def test_cdist_complex(ndim, input_dtype, out_dtype, metric, capability):
-    """Compares the simd.cdist() for complex numbers to pure NumPy complex dot-products, as SciPy has no such functionality."""
+    """Compares the simd.cdist() for complex numbers to pure NumPy complex dot-products, as SciPy has no such functionality.
+    The goal is to make sure that addressing multi-component numbers is done properly in both real and imaginary parts.
+    """

     np.random.seed()
     keep_one_capability(capability)

     # We will work with random matrices A (M x D) and B (N x D).
     # To test their ability to handle strided inputs, we are going to add one extra dimension.
-    A = np.random.randn(ndim).astype(input_dtype)
-    B = np.random.randn(ndim).astype(input_dtype)
-
-    expected = np.dot(A, B) if metric == "dot" else np.vdot(A, B)
+    M, N = 10, 15
+    A_extended = np.random.randn(M, ndim + 1).astype(input_dtype)
+    B_extended = np.random.randn(N, ndim + 3).astype(input_dtype)
+    A = A_extended[:, :ndim]
+    B = B_extended[:, :ndim]
+    C_extended = np.random.randn(M, N + 7).astype(out_dtype if out_dtype else np.complex128)
+    C = C_extended[:, :N]
+
+    #! Unlike `np.dot`, `np.vdot` flattens multi-dimensional inputs into 1D arrays.
+    #! So to compare the results we need to manually compute all the dot-products.
+    expected = np.zeros((M, N), dtype=out_dtype if out_dtype else np.complex128)
+    baseline_kernel = np.dot if metric == "dot" else np.vdot
+    for i in range(M):
+        for j in range(N):
+            expected[i, j] = baseline_kernel(A[i], B[j])
+
+    # Compute with SimSIMD:
     if out_dtype is None:
-        result1d = simd.cdist(A, B, metric=metric)
-        result2d = simd.cdist(A.reshape(1, ndim), B.reshape(1, ndim), metric=metric)
+        result1d = simd.cdist(A[0], B[0], metric=metric)
+        result2d = simd.cdist(A, B, metric=metric)
+        assert simd.cdist(A, B, metric=metric, out=C) is None
     else:
         expected = expected.astype(out_dtype)
-        result1d = simd.cdist(A, B, metric=metric, out_dtype=out_dtype)
-        result2d = simd.cdist(A.reshape(1, ndim), B.reshape(1, ndim), metric=metric, out_dtype=out_dtype)
+        result1d = simd.cdist(A[0], B[0], metric=metric, out_dtype=out_dtype)
+        result2d = simd.cdist(A, B, metric=metric, out_dtype=out_dtype)
+        assert simd.cdist(A, B, metric=metric, out_dtype=out_dtype, out=C) is None

     # Assert they're close.
- np.testing.assert_allclose(result1d, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) + np.testing.assert_allclose(result1d, expected[0, 0], atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) np.testing.assert_allclose(result2d, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) + np.testing.assert_allclose(C, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") diff --git a/setup.py b/setup.py index a9a8c36c..7974c414 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,11 @@ def get_bool_env_w_name(name: str, preference: bool) -> tuple: macros_args.extend( [ get_bool_env_w_name("SIMSIMD_TARGET_NEON", True), + get_bool_env_w_name("SIMSIMD_TARGET_NEON_F16", True), + get_bool_env_w_name("SIMSIMD_TARGET_NEON_BF16", True), get_bool_env_w_name("SIMSIMD_TARGET_SVE", True), + get_bool_env_w_name("SIMSIMD_TARGET_SVE_F16", True), + get_bool_env_w_name("SIMSIMD_TARGET_SVE_BF16", True), get_bool_env_w_name("SIMSIMD_TARGET_SVE2", True), get_bool_env_w_name("SIMSIMD_TARGET_HASWELL", True), get_bool_env_w_name("SIMSIMD_TARGET_SKYLAKE", True), @@ -101,6 +105,8 @@ def get_bool_env_w_name(name: str, preference: bool) -> tuple: macros_args.extend( [ get_bool_env_w_name("SIMSIMD_TARGET_NEON", True), + get_bool_env_w_name("SIMSIMD_TARGET_NEON_F16", True), # Supported on Apple M1 and newer + get_bool_env_w_name("SIMSIMD_TARGET_NEON_BF16", True), # Supported on Apple M2 and newer get_bool_env_w_name("SIMSIMD_TARGET_SVE", False), get_bool_env_w_name("SIMSIMD_TARGET_SVE2", False), get_bool_env_w_name("SIMSIMD_TARGET_HASWELL", True), @@ -124,9 +130,11 @@ def get_bool_env_w_name(name: str, preference: bool) -> tuple: compile_args.append("/d2FH4-") # We can't SIMD all the way on Windows :( + # Even NEON `f16` fails: https://github.com/ashvardanian/SimSIMD/actions/runs/11419164624/job/31773473319?pr=214 macros_args.extend( [ get_bool_env_w_name("SIMSIMD_TARGET_NEON", True), + get_bool_env_w_name("SIMSIMD_TARGET_NEON_F16", False), get_bool_env_w_name("SIMSIMD_TARGET_NEON_BF16", False), get_bool_env_w_name("SIMSIMD_TARGET_SVE", False), get_bool_env_w_name("SIMSIMD_TARGET_SVE2", False), @@ -161,7 +169,7 @@ def get_bool_env_w_name(name: str, preference: bool) -> tuple: author="Ash Vardanian", author_email="1983160+ashvardanian@users.noreply.github.com", url="https://github.com/ashvardanian/simsimd", - description="Fastest SIMD-Accelerated Vector Similarity Functions for x86 and Arm", + description="Portable mixed-precision BLAS-like vector math library for x86 and ARM", long_description=long_description, long_description_content_type="text/markdown", license="Apache-2.0",
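For a quick sanity check of the new `fma` and `wsum` kernels exposed by this patch, here is a minimal Python sketch, assuming NumPy inputs; the argument order and the `alpha`/`beta` semantics follow the `baseline_fma` and `baseline_wsum` references in `scripts/test.py`, and the tolerance is loosely chosen to absorb `f32` rounding:

```py
import numpy as np
import simsimd

a = np.random.randn(1536).astype(np.float32)
b = np.random.randn(1536).astype(np.float32)
c = np.random.randn(1536).astype(np.float32)

# Fused multiply-add: d = alpha * a * b + beta * c
d = np.asarray(simsimd.fma(a, b, c, alpha=0.2, beta=0.3))
assert np.allclose(d, 0.2 * a * b + 0.3 * c, atol=1e-4)

# Weighted sum: w = alpha * a + beta * b
w = np.asarray(simsimd.wsum(a, b, alpha=0.5, beta=0.5))
assert np.allclose(w, 0.5 * a + 0.5 * b, atol=1e-4)
```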