diff --git a/README.md b/README.md index 63744309..5d43f65d 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Implemented distance functions include: - Set Intersections for Sparse Vectors and Text Analysis. _[docs][docs-sparse]_ - Mahalanobis distance and Quadratic forms for Scientific Computing. _[docs][docs-curved]_ - Kullback-Leibler and Jensen–Shannon divergences for probability distributions. _[docs][docs-probability]_ +- Fused-Multiply-Add (FMA) and Weighted Sums to replace BLAS level 1 functions. _[docs][docs-fma]_ - For Levenshtein, Needleman–Wunsch, and Smith-Waterman, check [StringZilla][stringzilla]. - 🔜 Haversine and Vincenty's formulae for Geospatial Analysis. @@ -61,6 +62,7 @@ Implemented distance functions include: [docs-binary]: https://github.com/ashvardanian/SimSIMD/pull/138 [docs-dot]: #complex-dot-products-conjugate-dot-products-and-complex-numbers [docs-probability]: #logarithms-in-kullback-leibler--jensenshannon-divergences +[docs-fma]: #mixed-precision-in-fused-multiply-add-and-weighted-sums [scipy]: https://docs.scipy.org/doc/scipy/reference/spatial.distance.html#module-scipy.spatial.distance [numpy]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html [stringzilla]: https://github.com/ashvardanian/stringzilla @@ -122,7 +124,8 @@ Use the following snippet to install SimSIMD and list available hardware acceler ```sh pip install simsimd -python -c "import simsimd; print(simsimd.get_capabilities())" +python -c "import simsimd; print(simsimd.get_capabilities())" # for hardware introspection +python -c "import simsimd; help(simsimd)" # for documentation ``` With precompiled binaries, SimSIMD ships `.pyi` interface files for type hinting and static analysis. @@ -929,6 +932,36 @@ Jensen-Shannon divergence is a symmetrized and smoothed version of the Kullback- Both functions are defined for non-negative numbers, and the logarithm is a key part of their computation. 
+### Mixed Precision in Fused-Multiply-Add and Weighted Sums + +The Fused-Multiply-Add (FMA) operation is a single operation that combines element-wise multiplication and addition with different scaling factors. +The Weighted Sum is its simplified variant without element-wise multiplication. + +```math +\text{FMA}_i(A, B, C, \alpha, \beta) = \alpha \cdot A_i \cdot B_i + \beta \cdot C_i +``` + +```math +\text{WSum}_i(A, B, \alpha, \beta) = \alpha \cdot A_i + \beta \cdot B_i +``` + +In NumPy terms, the implementation may look like: + +```py +import numpy as np +def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray: + assert A.dtype == B.dtype, "Input types must match and affect the output type" + return (Alpha * A + Beta * B).astype(A.dtype) +def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray: + assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and affect the output type" + return (Alpha * A * B + Beta * C).astype(A.dtype) +``` + +The tricky part is implementing those operations in mixed precision, where the scaling factors are of different precision than the input and output vectors. +SimSIMD uses double-precision floating-point scaling factors for any input and output precision, including `i8` and `u8` integers and `f16` and `bf16` floats. +Depending on the generation of the CPU, given native support for `f16` addition and multiplication, the `f16` temporaries are used for `i8` and `u8` multiplication, scaling, and addition. +For `bf16`, native support is generally limited to dot-products with subsequent partial accumulation, which is not enough for the FMA and WSum operations, so `f32` is used as a temporary. + ### Auto-Vectorization & Loop Unrolling On the Intel Sapphire Rapids platform, SimSIMD was benchmarked against auto-vectorized code using GCC 12. 
diff --git a/c/lib.c b/c/lib.c index 951f50ed..d59724a1 100644 --- a/c/lib.c +++ b/c/lib.c @@ -107,6 +107,34 @@ extern "C" { metric(a, b, c, n, result); \ } +#define SIMSIMD_DECLARATION_FMA(name, extension, type) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension( \ + simsimd_##type##_t const *a, simsimd_##type##_t const *b, simsimd_##type##_t const *c, simsimd_size_t n, \ + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_##type##_t *result) { \ + static simsimd_kernel_fma_punned_t metric = 0; \ + if (metric == 0) { \ + simsimd_capability_t used_capability; \ + simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ + simsimd_capabilities(), simsimd_cap_any_k, \ + (simsimd_metric_punned_t *)(&metric), &used_capability); \ + } \ + metric(a, b, c, n, alpha, beta, result); \ + } + +#define SIMSIMD_DECLARATION_WSUM(name, extension, type) \ + SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const *a, simsimd_##type##_t const *b, \ + simsimd_size_t n, simsimd_distance_t alpha, \ + simsimd_distance_t beta, simsimd_##type##_t *result) { \ + static simsimd_kernel_wsum_punned_t metric = 0; \ + if (metric == 0) { \ + simsimd_capability_t used_capability; \ + simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \ + simsimd_capabilities(), simsimd_cap_any_k, \ + (simsimd_metric_punned_t *)(&metric), &used_capability); \ + } \ + metric(a, b, n, alpha, beta, result); \ + } + // Dot products SIMSIMD_DECLARATION_DENSE(dot, i8, i8) SIMSIMD_DECLARATION_DENSE(dot, u8, u8) @@ -171,6 +199,20 @@ SIMSIMD_DECLARATION_CURVED(mahalanobis, f16, f16) SIMSIMD_DECLARATION_CURVED(bilinear, bf16, bf16) SIMSIMD_DECLARATION_CURVED(mahalanobis, bf16, bf16) +// Element-wise operations +SIMSIMD_DECLARATION_FMA(fma, f64, f64) +SIMSIMD_DECLARATION_FMA(fma, f32, f32) +SIMSIMD_DECLARATION_FMA(fma, f16, f16) +SIMSIMD_DECLARATION_FMA(fma, bf16, bf16) +SIMSIMD_DECLARATION_FMA(fma, i8, i8) 
+SIMSIMD_DECLARATION_FMA(fma, u8, u8) +SIMSIMD_DECLARATION_WSUM(wsum, f64, f64) +SIMSIMD_DECLARATION_WSUM(wsum, f32, f32) +SIMSIMD_DECLARATION_WSUM(wsum, f16, f16) +SIMSIMD_DECLARATION_WSUM(wsum, bf16, bf16) +SIMSIMD_DECLARATION_WSUM(wsum, i8, i8) +SIMSIMD_DECLARATION_WSUM(wsum, u8, u8) + SIMSIMD_DYNAMIC int simsimd_uses_neon(void) { return (simsimd_capabilities() & simsimd_cap_neon_k) != 0; } SIMSIMD_DYNAMIC int simsimd_uses_neon_f16(void) { return (simsimd_capabilities() & simsimd_cap_neon_f16_k) != 0; } SIMSIMD_DYNAMIC int simsimd_uses_neon_bf16(void) { return (simsimd_capabilities() & simsimd_cap_neon_bf16_k) != 0; } @@ -200,73 +242,86 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { // with dummy inputs: simsimd_distance_t dummy_results_buffer[2]; simsimd_distance_t *dummy_results = &dummy_results_buffer[0]; - void *dummy = 0; + void *x = 0; // Dense: - simsimd_dot_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); - simsimd_dot_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); - simsimd_dot_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_dot_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_dot_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_dot_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - - simsimd_dot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_dot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_dot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_dot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_vdot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_vdot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_vdot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, 
dummy_results); - simsimd_vdot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - - simsimd_cos_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); - simsimd_cos_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); - simsimd_cos_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_cos_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_cos_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_cos_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - - simsimd_l2sq_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); - simsimd_l2sq_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); - simsimd_l2sq_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_l2sq_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_l2sq_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_l2sq_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - - simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); - simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results); - simsimd_l2_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results); - simsimd_l2_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_l2_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_l2_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_l2_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - - simsimd_hamming_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results); - simsimd_jaccard_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results); - - simsimd_kl_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - 
simsimd_kl_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_kl_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_kl_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_js_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_js_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_js_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_js_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); + simsimd_dot_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); + simsimd_dot_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results); + simsimd_dot_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_dot_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_dot_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_dot_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + + simsimd_dot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_dot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_dot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_dot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + simsimd_vdot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_vdot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_vdot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_vdot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + + simsimd_cos_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); + simsimd_cos_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results); + simsimd_cos_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_cos_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + 
simsimd_cos_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_cos_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + + simsimd_l2sq_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); + simsimd_l2sq_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results); + simsimd_l2sq_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_l2sq_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_l2sq_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_l2sq_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + + simsimd_l2_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); + simsimd_l2_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results); + simsimd_l2_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results); + simsimd_l2_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_l2_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_l2_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_l2_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + + simsimd_hamming_b8((simsimd_b8_t *)x, (simsimd_b8_t *)x, 0, dummy_results); + simsimd_jaccard_b8((simsimd_b8_t *)x, (simsimd_b8_t *)x, 0, dummy_results); + + simsimd_kl_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_kl_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_kl_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_kl_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + simsimd_js_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_js_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + simsimd_js_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_js_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); // Sparse - simsimd_intersect_u16((simsimd_u16_t *)dummy, (simsimd_u16_t 
*)dummy, 0, 0, dummy_results); - simsimd_intersect_u32((simsimd_u32_t *)dummy, (simsimd_u32_t *)dummy, 0, 0, dummy_results); + simsimd_intersect_u16((simsimd_u16_t *)x, (simsimd_u16_t *)x, 0, 0, dummy_results); + simsimd_intersect_u32((simsimd_u32_t *)x, (simsimd_u32_t *)x, 0, 0, dummy_results); // Curved: - simsimd_bilinear_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_mahalanobis_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results); - simsimd_bilinear_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_mahalanobis_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results); - simsimd_bilinear_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_mahalanobis_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results); - simsimd_bilinear_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results); - simsimd_mahalanobis_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, - dummy_results); + simsimd_bilinear_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + simsimd_mahalanobis_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results); + simsimd_bilinear_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_mahalanobis_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results); + simsimd_bilinear_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_mahalanobis_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results); + simsimd_bilinear_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + 
simsimd_mahalanobis_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results); + + // Elementwise + simsimd_wsum_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, 0, 0, (simsimd_f64_t *)x); + simsimd_wsum_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, 0, 0, (simsimd_f32_t *)x); + simsimd_wsum_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, 0, 0, (simsimd_f16_t *)x); + simsimd_wsum_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, 0, 0, (simsimd_bf16_t *)x); + simsimd_wsum_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, 0, 0, (simsimd_i8_t *)x); + simsimd_wsum_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, 0, 0, (simsimd_u8_t *)x); + simsimd_fma_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, 0, 0, (simsimd_f64_t *)x); + simsimd_fma_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, 0, 0, (simsimd_f32_t *)x); + simsimd_fma_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, 0, 0, (simsimd_f16_t *)x); + simsimd_fma_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, 0, 0, (simsimd_bf16_t *)x); + simsimd_fma_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, 0, 0, (simsimd_i8_t *)x); + simsimd_fma_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, 0, 0, (simsimd_u8_t *)x); return static_capabilities; } diff --git a/include/simsimd/fma.h b/include/simsimd/elementwise.h similarity index 89% rename from include/simsimd/fma.h rename to include/simsimd/elementwise.h index ab99027c..5bf86d2e 100644 --- a/include/simsimd/fma.h +++ b/include/simsimd/elementwise.h @@ -1,6 +1,6 @@ /** - * @file fma.h - * @brief SIMD-accelerated mixed-precision Fused-Multiply-Add operations. + * @file elementwise.h + * @brief SIMD-accelerated mixed-precision element-wise operations. * @author Ash Vardanian * @date October 16, 2024 * @@ -43,6 +43,10 @@ extern "C" { #endif +/* Serial backends for all numeric types. 
+ * By default they use 32-bit arithmetic, unless the arguments themselves contain 64-bit floats. + * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. + */ SIMSIMD_PUBLIC void simsimd_wsum_f64_serial( // simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); @@ -136,6 +140,48 @@ SIMSIMD_MAKE_FMA(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) SIMSIMD_MAKE_FMA(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_fma_i8_accurate SIMSIMD_MAKE_FMA(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_fma_u8_accurate +/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. + * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all + * server CPUs produced before 2023. + */ +SIMSIMD_PUBLIC void simsimd_wsum_f32_neon( // + simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_f16_neon( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_bf16_neon( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_u8_neon( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); +SIMSIMD_PUBLIC void simsimd_wsum_i8_neon( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); + +SIMSIMD_PUBLIC void simsimd_fma_f32_neon( // + simsimd_f32_t const *a, 
 simsimd_f32_t const *b, simsimd_f32_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *result); +SIMSIMD_PUBLIC void simsimd_fma_f16_neon( // + simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_bf16_neon( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result); +SIMSIMD_PUBLIC void simsimd_fma_u8_neon( // + simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result); +SIMSIMD_PUBLIC void simsimd_fma_i8_neon( // + simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result); + +/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words. + * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420. + * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms. + * On the other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are + * properly vectorized by recent compilers. 
+ */ SIMSIMD_PUBLIC void simsimd_wsum_f64_haswell( // simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, // simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *result); @@ -981,6 +1027,62 @@ SIMSIMD_PUBLIC void simsimd_fma_f32_neon( // #pragma GCC pop_options #endif // SIMSIMD_TARGET_NEON +#if SIMSIMD_TARGET_NEON_BF16 +#pragma GCC push_options +#pragma GCC target("arch=armv8.6-a+simd+bf16") +#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_wsum_bf16_neon( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_bf16(vld1_bf16((bfloat16_t const *)a + i)); + float32x4_t b_vec = vcvt_f32_bf16(vld1_bf16((bfloat16_t const *)b + i)); + float32x4_t a_scaled_vec = vmulq_n_f32(a_vec, alpha_f32); + float32x4_t b_scaled_vec = vmulq_n_f32(b_vec, beta_f32); + float32x4_t sum_vec = vaddq_f32(a_scaled_vec, b_scaled_vec); + vst1_bf16((bfloat16_t *)result + i, vcvt_bf16_f32(sum_vec)); + } + + // The tail: + for (; i < n; ++i) + simsimd_f32_to_bf16(alpha_f32 * simsimd_bf16_to_f32(a + i) + beta_f32 * simsimd_bf16_to_f32(b + i), result + i); +} + +SIMSIMD_PUBLIC void simsimd_fma_bf16_neon( // + simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *result) { + simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha; + simsimd_f32_t beta_f32 = (simsimd_f32_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_bf16(vld1_bf16((bfloat16_t const *)a + i)); + float32x4_t b_vec = vcvt_f32_bf16(vld1_bf16((bfloat16_t const 
*)b + i)); + float32x4_t c_vec = vcvt_f32_bf16(vld1_bf16((bfloat16_t const *)c + i)); + float32x4_t ab_vec = vmulq_f32(a_vec, b_vec); + float32x4_t ab_scaled_vec = vmulq_n_f32(ab_vec, alpha_f32); + float32x4_t sum_vec = vfmaq_n_f32(ab_scaled_vec, c_vec, beta_f32); + vst1_bf16((bfloat16_t *)result + i, vcvt_bf16_f32(sum_vec)); + } + + // The tail: + for (; i < n; ++i) + simsimd_f32_to_bf16( + alpha_f32 * simsimd_bf16_to_f32(a + i) * simsimd_bf16_to_f32(b + i) + beta_f32 * simsimd_bf16_to_f32(c + i), + result + i); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON_BF16 + #if SIMSIMD_TARGET_NEON_F16 #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd+fp16") @@ -995,12 +1097,12 @@ SIMSIMD_PUBLIC void simsimd_wsum_f16_neon( // // The main loop: simsimd_size_t i = 0; for (; i + 8 <= n; i += 8) { - float16x8_t a_vec = vld1q_f16(a + i); - float16x8_t b_vec = vld1q_f16(b + i); + float16x8_t a_vec = vld1q_f16((float16_t const *)a + i); + float16x8_t b_vec = vld1q_f16((float16_t const *)b + i); float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); - vst1q_f16(result + i, sum_vec); + vst1q_f16((float16_t *)result + i, sum_vec); } // The tail: @@ -1017,13 +1119,13 @@ SIMSIMD_PUBLIC void simsimd_fma_f16_neon( // // The main loop: simsimd_size_t i = 0; for (; i + 8 <= n; i += 8) { - float16x8_t a_vec = vld1q_f16(a + i); - float16x8_t b_vec = vld1q_f16(b + i); - float16x8_t c_vec = vld1q_f16(c + i); + float16x8_t a_vec = vld1q_f16((float16_t const *)a + i); + float16x8_t b_vec = vld1q_f16((float16_t const *)b + i); + float16x8_t c_vec = vld1q_f16((float16_t const *)c + i); float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); - vst1q_f16(result + i, sum_vec); + vst1q_f16((float16_t *)result + 
i, sum_vec); } // The tail: @@ -1048,7 +1150,7 @@ SIMSIMD_PUBLIC void simsimd_wsum_u8_neon( // float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); - uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec)); + uint8x8_t sum_u8_vec = vqmovn_u16(vcvtaq_u16_f16(sum_vec)); vst1_u8(result + i, sum_u8_vec); } @@ -1074,7 +1176,7 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_neon( // float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); - uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec)); + uint8x8_t sum_u8_vec = vqmovn_u16(vcvtaq_u16_f16(sum_vec)); vst1_u8(result + i, sum_u8_vec); } @@ -1098,7 +1200,7 @@ SIMSIMD_PUBLIC void simsimd_wsum_i8_neon( // float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); - int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec)); + int8x8_t sum_i8_vec = vqmovn_s16(vcvtaq_s16_f16(sum_vec)); vst1_s8(result + i, sum_i8_vec); } @@ -1124,7 +1226,7 @@ SIMSIMD_PUBLIC void simsimd_fma_i8_neon( // float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); - int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec)); + int8x8_t sum_i8_vec = vqmovn_s16(vcvtaq_s16_f16(sum_vec)); vst1_s8(result + i, sum_i8_vec); } diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index 7e62cdb1..ff3b5c98 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -105,7 +105,7 @@ #include "binary.h" // Hamming, Jaccard #include "curved.h" // Mahalanobis, Bilinear Forms #include "dot.h" // Inner (dot) product, and its conjugate -#include "fma.h" // Weighted Sum, Fused 
Multiply-Add +#include "elementwise.h" // Weighted Sum, Fused-Multiply-Add #include "geospatial.h" // Haversine and Vincenty #include "probability.h" // Kullback-Leibler, Jensen–Shannon #include "sparse.h" // Intersect @@ -613,6 +613,8 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f32(simsimd_capability_t v, si case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_f32_neon, *c = simsimd_cap_neon_k; return; case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_neon, *c = simsimd_cap_neon_k; return; case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f32_neon, *c = simsimd_cap_neon_k; return; default: break; } #endif @@ -681,6 +683,8 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_f16(simsimd_capability_t v, si case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_neon, *c = simsimd_cap_neon_f16_k; return; case simsimd_metric_bilinear_k: *m = (m_t)&simsimd_bilinear_f16_neon, *c = simsimd_cap_neon_f16_k; return; case simsimd_metric_mahalanobis_k: *m = (m_t)&simsimd_mahalanobis_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_f16_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_f16_neon, *c = simsimd_cap_neon_f16_k; return; default: break; } #endif @@ -750,6 +754,8 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_bf16(simsimd_capability_t v, s case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_neon, *c = simsimd_cap_neon_bf16_k; return; case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_neon, *c = simsimd_cap_neon_bf16_k; return; case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_bf16_neon, *c = simsimd_cap_neon_bf16_k; return; + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_bf16_neon, *c = simsimd_cap_neon_bf16_k; return; + case simsimd_metric_wsum_k: *m = 
(m_t)&simsimd_wsum_bf16_neon, *c = simsimd_cap_neon_bf16_k; return; default: break; } #endif @@ -815,7 +821,14 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_i8(simsimd_capability_t v, sim default: break; } #endif -#if SIMSIMD_TARGET_SAPPHIRE +#if SIMSIMD_TARGET_NEON_F16 //! Scaling of 8-bit integers is performed using 16-bit floats. + if (v & simsimd_cap_neon_f16_k) switch (k) { + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_i8_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_i8_neon, *c = simsimd_cap_neon_f16_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SAPPHIRE //! Scaling of 8-bit integers is performed using 16-bit floats. if (v & simsimd_cap_sapphire_k) switch (k) { case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_i8_sapphire, *c = simsimd_cap_sapphire_k; return; case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_i8_sapphire, *c = simsimd_cap_sapphire_k; return; @@ -864,7 +877,14 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_u8(simsimd_capability_t v, sim default: break; } #endif -#if SIMSIMD_TARGET_SAPPHIRE +#if SIMSIMD_TARGET_NEON_F16 //! Scaling of 8-bit integers is performed using 16-bit floats. + if (v & simsimd_cap_neon_f16_k) switch (k) { + case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_u8_neon, *c = simsimd_cap_neon_f16_k; return; + case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_u8_neon, *c = simsimd_cap_neon_f16_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SAPPHIRE //! Scaling of 8-bit integers is performed using 16-bit floats. 
 if (v & simsimd_cap_sapphire_k) switch (k) { case simsimd_metric_fma_k: *m = (m_t)&simsimd_fma_u8_sapphire, *c = simsimd_cap_sapphire_k; return; case simsimd_metric_wsum_k: *m = (m_t)&simsimd_wsum_u8_sapphire, *c = simsimd_cap_sapphire_k; return; @@ -1164,9 +1184,9 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( // __asm__ __volatile__("" ::: "memory"); #endif - volatile simsimd_metric_punned_t *m = metric_output; - volatile simsimd_capability_t *c = capability_output; - volatile simsimd_capability_t viable = (simsimd_capability_t)(supported & allowed); + simsimd_metric_punned_t *m = metric_output; + simsimd_capability_t *c = capability_output; + simsimd_capability_t viable = (simsimd_capability_t)(supported & allowed); switch (datatype) { @@ -2147,6 +2167,174 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16(simsimd_bf16_t const *a, simsimd_bf #endif } +/* Elementwise operations + * + * @param a The first vector of integral or floating point values. + * @param b The second vector of integral or floating point values. + * @param c The third vector of integral or floating point values. + * @param n The number of dimensions in the vectors. + * @param alpha The first scaling factor. + * @param beta The second scaling factor. + * @param r The output vector of integral or floating point values. 
+ */ +SIMSIMD_PUBLIC void simsimd_wsum_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f64_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_wsum_f64_skylake(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_f64_haswell(a, b, n, alpha, beta, r); +#else + simsimd_wsum_f64_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_wsum_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f32_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_wsum_f32_skylake(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_f32_haswell(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON + simsimd_wsum_f32_neon(a, b, n, alpha, beta, r); +#else + simsimd_wsum_f32_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_wsum_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_size_t n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_bf16_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_wsum_bf16_skylake(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_bf16_haswell(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON + simsimd_wsum_bf16_neon(a, b, n, alpha, beta, r); +#else + simsimd_wsum_bf16_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_wsum_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_size_t n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_f16_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_wsum_f16_sapphire(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_f16_haswell(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_wsum_f16_neon(a, b, n, alpha, beta, r); +#else + simsimd_wsum_f16_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_wsum_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t 
n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_wsum_i8_sapphire(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_i8_haswell(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_wsum_i8_neon(a, b, n, alpha, beta, r); +#else + simsimd_wsum_i8_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_wsum_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_wsum_u8_sapphire(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_wsum_u8_haswell(a, b, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_wsum_u8_neon(a, b, n, alpha, beta, r); +#else + simsimd_wsum_u8_serial(a, b, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_f64(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + simsimd_f64_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_fma_f64_skylake(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_f64_haswell(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_f64_serial(a, b, c, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_f32(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + simsimd_f32_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_fma_f32_skylake(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_f32_haswell(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON + simsimd_fma_f32_neon(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_f32_serial(a, b, c, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_bf16(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_bf16_t const *c, + simsimd_size_t n, simsimd_distance_t 
alpha, simsimd_distance_t beta, + simsimd_bf16_t *r) { +#if SIMSIMD_TARGET_SKYLAKE + simsimd_fma_bf16_skylake(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_bf16_haswell(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON + simsimd_fma_bf16_neon(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_bf16_serial(a, b, c, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_f16(simsimd_f16_t const *a, simsimd_f16_t const *b, simsimd_f16_t const *c, + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + simsimd_f16_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_fma_f16_sapphire(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_f16_haswell(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_fma_f16_neon(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_f16_serial(a, b, c, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_i8(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + simsimd_i8_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_fma_i8_sapphire(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_i8_haswell(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_fma_i8_neon(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_i8_serial(a, b, c, n, alpha, beta, r); +#endif +} + +SIMSIMD_PUBLIC void simsimd_fma_u8(simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, + simsimd_u8_t *r) { +#if SIMSIMD_TARGET_SAPPHIRE + simsimd_fma_u8_sapphire(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_HASWELL + simsimd_fma_u8_haswell(a, b, c, n, alpha, beta, r); +#elif SIMSIMD_TARGET_NEON_F16 + simsimd_fma_u8_neon(a, b, c, n, alpha, beta, r); +#else + simsimd_fma_u8_serial(a, b, c, n, alpha, beta, r); +#endif +} + #endif #ifdef __cplusplus diff --git 
a/python/annotations/__init__.pyi b/python/annotations/__init__.pyi index 987ea9f0..484387fc 100644 --- a/python/annotations/__init__.pyi +++ b/python/annotations/__init__.pyi @@ -27,6 +27,8 @@ _MetricType = Literal[ "intersection", "bilinear", "mahalanobis", + "fma", + "wsum", ] _IntegralType = Literal[ # Booleans @@ -115,8 +117,9 @@ def cdist( *, threads: int = 1, dtype: Optional[Union[_IntegralType, _FloatType, _ComplexType]] = None, + out: Optional[_BufferType] = None, out_dtype: Union[_FloatType, _ComplexType] = "d", -) -> Union[float, complex, DistancesTensor]: ... +) -> Optional[Union[float, complex, DistancesTensor]]: ... # --------------------------------------------------------------------- # Vector-vector dot products for real and complex numbers @@ -129,7 +132,10 @@ def inner( b: _BufferType, /, dtype: Optional[Union[_FloatType, _ComplexType]] = None, -) -> Union[float, complex, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType, _ComplexType] = "d", +) -> Optional[Union[float, complex, DistancesTensor]]: ... # Dot product, similar to: `numpy.dot`. # https://numpy.org/doc/stable/reference/generated/numpy.dot.html @@ -138,7 +144,10 @@ def dot( b: _BufferType, /, dtype: Optional[Union[_FloatType, _ComplexType]] = None, -) -> Union[float, complex, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType, _ComplexType] = None, +) -> Optional[Union[float, complex, DistancesTensor]]: ... # Vector-vector dot product for complex conjugates, similar to: `numpy.vdot`. # https://numpy.org/doc/stable/reference/generated/numpy.vdot.html @@ -147,7 +156,10 @@ def vdot( b: _BufferType, /, dtype: Optional[_ComplexType] = None, -) -> Union[complex, DistancesTensor]: ... + *, + out: Optional[Union[float, complex, DistancesTensor]] = None, + out_dtype: Optional[_ComplexType] = None, +) -> Optional[Union[complex, DistancesTensor]]: ... 
# --------------------------------------------------------------------- # Vector-vector spatial distance metrics for real and integer numbers @@ -161,7 +173,10 @@ def sqeuclidean( b: _BufferType, /, dtype: Optional[Union[_IntegralType, _FloatType]] = None, -) -> Union[float, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # Vector-vector cosine distance, similar to: `scipy.spatial.distance.cosine`. # https://docs.scipy.org/doc/scipy-1.11.4/reference/generated/scipy.spatial.distance.cosine.html @@ -170,7 +185,10 @@ def cosine( b: _BufferType, /, dtype: Optional[Union[_IntegralType, _FloatType]] = None, -) -> Union[float, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # --------------------------------------------------------------------- # Vector-vector similarity functions for binary vectors @@ -183,7 +201,10 @@ def hamming( b: _BufferType, /, dtype: Optional[_IntegralType] = None, -) -> Union[float, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # Vector-vector Jaccard distance, similar to: `scipy.spatial.distance.jaccard`. # https://docs.scipy.org/doc/scipy-1.11.4/reference/generated/scipy.spatial.distance.jaccard.html @@ -192,7 +213,10 @@ def jaccard( b: _BufferType, /, dtype: Optional[_IntegralType] = None, -) -> Union[float, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # --------------------------------------------------------------------- # Vector-vector similarity between probability distributions @@ -205,7 +229,10 @@ def jensenshannon( b: _BufferType, /, dtype: Optional[_FloatType] = None, -) -> Union[float, DistancesTensor]: ... 
+ *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # Vector-vector Kullback-Leibler divergence, similar to: `scipy.spatial.distance.kullback_leibler`. # https://docs.scipy.org/doc/scipy-1.11.4/reference/generated/scipy.spatial.distance.kullback_leibler.html @@ -214,7 +241,10 @@ def kullbackleibler( b: _BufferType, /, dtype: Optional[_FloatType] = None, -) -> Union[float, DistancesTensor]: ... + *, + out: Optional[_BufferType] = None, + out_dtype: Union[_FloatType] = None, +) -> Optional[Union[float, DistancesTensor]]: ... # --------------------------------------------------------------------- # Vector-vector similarity between vectors in curved spaces @@ -247,3 +277,32 @@ def mahalanobis( # Vector-vector intersection similarity, similar to: `numpy.intersect1d`. # https://numpy.org/doc/stable/reference/generated/numpy.intersect1d.html def intersection(array1: _BufferType, array2: _BufferType, /) -> float: ... + +# --------------------------------------------------------------------- +# Vector-vector math: FMA, WSum +# --------------------------------------------------------------------- + +# Vector-vector element-wise fused-multiply add. +def fma( + a: _BufferType, + b: _BufferType, + c: _BufferType, + /, + dtype: Optional[Union[_FloatType, _IntegralType]] = None, + *, + alpha: float = 1, + beta: float = 1, + out: Optional[_BufferType] = None, +) -> Optional[DistancesTensor]: ... + +# Vector-vector element-wise weighted sum. +def wsum( + a: _BufferType, + b: _BufferType, + /, + dtype: Optional[Union[_FloatType, _IntegralType]] = None, + *, + alpha: float = 1, + beta: float = 1, + out: Optional[_BufferType] = None, +) -> Optional[DistancesTensor]: ... 
diff --git a/scripts/test.py b/scripts/test.py index 85b23e67..f8b9bbdd 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -609,6 +609,7 @@ def test_dense(ndim, dtype, metric, capability, stats_fixture): accurate_dt, accurate = profile(baseline_kernel, a.astype(np.float64), b.astype(np.float64)) expected_dt, expected = profile(baseline_kernel, a, b) result_dt, result = profile(simd_kernel, a, b) + result = np.array(result) np.testing.assert_allclose(result, expected.astype(np.float64), atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors(metric, ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) @@ -664,6 +665,7 @@ def test_curved(ndim, dtypes, metric, capability, stats_fixture): c.astype(compute_dtype), ) result_dt, result = profile(simd_kernel, a, b, c) + result = np.array(result) np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors(metric, ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) @@ -690,6 +692,7 @@ def test_dense_bf16(ndim, metric, capability, stats_fixture): accurate_dt, accurate = profile(baseline_kernel, a_f32_rounded.astype(np.float64), b_f32_rounded.astype(np.float64)) expected_dt, expected = profile(baseline_kernel, a_f32_rounded, b_f32_rounded) result_dt, result = profile(simd_kernel, a_bf16, b_bf16, "bf16") + result = np.array(result) np.testing.assert_allclose( result, @@ -746,6 +749,7 @@ def test_curved_bf16(ndim, metric, capability, stats_fixture): ) expected_dt, expected = profile(baseline_kernel, a_f32_rounded, b_f32_rounded, c_f32_rounded) result_dt, result = profile(simd_kernel, a_bf16, b_bf16, c_bf16, "bf16") + result = np.array(result) np.testing.assert_allclose( result, @@ -791,6 +795,7 @@ def test_dense_i8(ndim, dtype, metric, capability, stats_fixture): accurate_dt, accurate = profile(baseline_kernel, a.astype(np.float64), b.astype(np.float64)) expected_dt, expected = profile(baseline_kernel, 
a.astype(np.int64), b.astype(np.int64)) result_dt, result = profile(simd_kernel, a, b) + result = np.array(result) if metric == "inner": assert round(float(result)) == round(float(expected)), f"Expected {expected}, but got {result}" @@ -828,6 +833,7 @@ def test_dense_bits(ndim, metric, capability, stats_fixture): accurate_dt, accurate = profile(baseline_kernel, a.astype(np.uint64), b.astype(np.uint64)) expected_dt, expected = profile(baseline_kernel, a, b) result_dt, result = profile(simd_kernel, np.packbits(a), np.packbits(b), "b8") + result = np.array(result) np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors(metric, ndim, "bits", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) @@ -851,6 +857,7 @@ def test_jensen_shannon(ndim, dtype, capability): accurate_dt, accurate = profile(baseline_kernel, a.astype(np.float64), b.astype(np.float64)) expected_dt, expected = profile(baseline_kernel, a, b) result_dt, result = profile(simd_kernel, a, b) + result = np.array(result) np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors( @@ -962,6 +969,7 @@ def test_dot_complex(ndim, dtype, capability, stats_fixture): accurate_dt, accurate = profile(np.dot, a.astype(np.complex128), b.astype(np.complex128)) expected_dt, expected = profile(np.dot, a, b) result_dt, result = profile(simd.dot, a, b) + result = np.array(result) np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors( @@ -971,6 +979,7 @@ def test_dot_complex(ndim, dtype, capability, stats_fixture): accurate_dt, accurate = profile(np.vdot, a.astype(np.complex128), b.astype(np.complex128)) expected_dt, expected = profile(np.vdot, a, b) result_dt, result = profile(simd.vdot, a, b) + result = np.array(result) np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL) collect_errors( @@ -1075,6 +1084,7 @@ def test_fma(ndim, dtype, kernel, 
capability, stats_fixture): ) expected_dt, expected = profile(baseline_kernel, a, b, c, alpha=alpha, beta=beta) result_dt, result = profile(simd_kernel, a, b, c, alpha=alpha, beta=beta) + result = np.array(result) np.testing.assert_allclose(result, expected.astype(np.float64), atol=atol, rtol=rtol) collect_errors( @@ -1132,6 +1142,7 @@ def test_wsum(ndim, dtype, kernel, capability, stats_fixture): ) expected_dt, expected = profile(baseline_kernel, a, b, alpha=alpha, beta=beta) result_dt, result = profile(simd_kernel, a, b, alpha=alpha, beta=beta) + result = np.array(result) np.testing.assert_allclose(result, expected.astype(np.float64), atol=atol, rtol=rtol) collect_errors(