Skip to content

Commit

Permalink
Merge pull request #215 from ashvardanian/main-dev
Browse files Browse the repository at this point in the history
Missing NEON Dispatch
  • Loading branch information
ashvardanian authored Oct 27, 2024
2 parents 2a0a5bd + a32b187 commit efd2a52
Show file tree
Hide file tree
Showing 6 changed files with 539 additions and 91 deletions.
35 changes: 34 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Implemented distance functions include:
- Set Intersections for Sparse Vectors and Text Analysis. _[docs][docs-sparse]_
- Mahalanobis distance and Quadratic forms for Scientific Computing. _[docs][docs-curved]_
- Kullback-Leibler and Jensen–Shannon divergences for probability distributions. _[docs][docs-probability]_
- Fused-Multiply-Add (FMA) and Weighted Sums to replace BLAS level 1 functions. _[docs][docs-fma]_
- For Levenshtein, Needleman–Wunsch, and Smith-Waterman, check [StringZilla][stringzilla].
- 🔜 Haversine and Vincenty's formulae for Geospatial Analysis.

Expand All @@ -61,6 +62,7 @@ Implemented distance functions include:
[docs-binary]: https://github.com/ashvardanian/SimSIMD/pull/138
[docs-dot]: #complex-dot-products-conjugate-dot-products-and-complex-numbers
[docs-probability]: #logarithms-in-kullback-leibler--jensenshannon-divergences
[docs-fma]: #mixed-precision-in-fused-multiply-add-and-weighted-sums
[scipy]: https://docs.scipy.org/doc/scipy/reference/spatial.distance.html#module-scipy.spatial.distance
[numpy]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
[stringzilla]: https://github.com/ashvardanian/stringzilla
Expand Down Expand Up @@ -122,7 +124,8 @@ Use the following snippet to install SimSIMD and list available hardware acceler

```sh
pip install simsimd
python -c "import simsimd; print(simsimd.get_capabilities())"
python -c "import simsimd; print(simsimd.get_capabilities())" # for hardware introspection
python -c "import simsimd; help(simsimd)" # for documentation
```

With precompiled binaries, SimSIMD ships `.pyi` interface files for type hinting and static analysis.
Expand Down Expand Up @@ -929,6 +932,36 @@ Jensen-Shannon divergence is a symmetrized and smoothed version of the Kullback-

Both functions are defined for non-negative numbers, and the logarithm is a key part of their computation.

### Mixed Precision in Fused-Multiply-Add and Weighted Sums

The Fused-Multiply-Add (FMA) operation is a single operation that combines element-wise multiplication and addition with different scaling factors.
The Weighted Sum is it's simplified variant without element-wise multiplication.

```math
\text{FMA}_i(A, B, C, \alpha, \beta) = \alpha \cdot A_i \cdot B_i + \beta \cdot C_i
```

```math
\text{WSum}_i(A, B, \alpha, \beta) = \alpha \cdot A_i + \beta \cdot B_i
```

In NumPy terms, the implementation may look like:

```py
import numpy as np
def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
assert A.dtype == B.dtype, "Input types must match and affect the output style"
return (Alpha * A + Beta * B).astype(A.dtype)
def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and affect the output style"
return (Alpha * A * B + Beta * C).astype(A.dtype)
```

The tricky part is implementing those operations in mixed precision, where the scaling factors are of different precision than the input and output vectors.
SimSIMD uses double-precision floating-point scaling factors for any input and output precision, including `i8` and `u8` integers and `f16` and `bf16` floats.
Depending on the generation of the CPU, given native support for `f16` addition and multiplication, the `f16` temporaries are used for `i8` and `u8` multiplication, scaling, and addition.
For `bf16`, native support is generally limited to dot-products with subsequent partial accumulation, which is not enough for the FMA and WSum operations, so `f32` is used as a temporary.

### Auto-Vectorization & Loop Unrolling

On the Intel Sapphire Rapids platform, SimSIMD was benchmarked against auto-vectorized code using GCC 12.
Expand Down
177 changes: 116 additions & 61 deletions c/lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ extern "C" {
metric(a, b, c, n, result); \
}

#define SIMSIMD_DECLARATION_FMA(name, extension, type) \
SIMSIMD_DYNAMIC void simsimd_##name##_##extension( \
simsimd_##type##_t const *a, simsimd_##type##_t const *b, simsimd_##type##_t const *c, simsimd_size_t n, \
simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_##type##_t *result) { \
static simsimd_kernel_fma_punned_t metric = 0; \
if (metric == 0) { \
simsimd_capability_t used_capability; \
simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \
simsimd_capabilities(), simsimd_cap_any_k, \
(simsimd_metric_punned_t *)(&metric), &used_capability); \
} \
metric(a, b, c, n, alpha, beta, result); \
}

#define SIMSIMD_DECLARATION_WSUM(name, extension, type) \
SIMSIMD_DYNAMIC void simsimd_##name##_##extension(simsimd_##type##_t const *a, simsimd_##type##_t const *b, \
simsimd_size_t n, simsimd_distance_t alpha, \
simsimd_distance_t beta, simsimd_##type##_t *result) { \
static simsimd_kernel_wsum_punned_t metric = 0; \
if (metric == 0) { \
simsimd_capability_t used_capability; \
simsimd_find_metric_punned(simsimd_metric_##name##_k, simsimd_datatype_##extension##_k, \
simsimd_capabilities(), simsimd_cap_any_k, \
(simsimd_metric_punned_t *)(&metric), &used_capability); \
} \
metric(a, b, n, alpha, beta, result); \
}

// Dot products
SIMSIMD_DECLARATION_DENSE(dot, i8, i8)
SIMSIMD_DECLARATION_DENSE(dot, u8, u8)
Expand Down Expand Up @@ -171,6 +199,20 @@ SIMSIMD_DECLARATION_CURVED(mahalanobis, f16, f16)
SIMSIMD_DECLARATION_CURVED(bilinear, bf16, bf16)
SIMSIMD_DECLARATION_CURVED(mahalanobis, bf16, bf16)

// Element-wise operations
SIMSIMD_DECLARATION_FMA(fma, f64, f64)
SIMSIMD_DECLARATION_FMA(fma, f32, f32)
SIMSIMD_DECLARATION_FMA(fma, f16, f16)
SIMSIMD_DECLARATION_FMA(fma, bf16, bf16)
SIMSIMD_DECLARATION_FMA(fma, i8, i8)
SIMSIMD_DECLARATION_FMA(fma, u8, u8)
SIMSIMD_DECLARATION_WSUM(wsum, f64, f64)
SIMSIMD_DECLARATION_WSUM(wsum, f32, f32)
SIMSIMD_DECLARATION_WSUM(wsum, f16, f16)
SIMSIMD_DECLARATION_WSUM(wsum, bf16, bf16)
SIMSIMD_DECLARATION_WSUM(wsum, i8, i8)
SIMSIMD_DECLARATION_WSUM(wsum, u8, u8)

SIMSIMD_DYNAMIC int simsimd_uses_neon(void) { return (simsimd_capabilities() & simsimd_cap_neon_k) != 0; }
SIMSIMD_DYNAMIC int simsimd_uses_neon_f16(void) { return (simsimd_capabilities() & simsimd_cap_neon_f16_k) != 0; }
SIMSIMD_DYNAMIC int simsimd_uses_neon_bf16(void) { return (simsimd_capabilities() & simsimd_cap_neon_bf16_k) != 0; }
Expand Down Expand Up @@ -200,73 +242,86 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) {
// with dummy inputs:
simsimd_distance_t dummy_results_buffer[2];
simsimd_distance_t *dummy_results = &dummy_results_buffer[0];
void *dummy = 0;
void *x = 0;

// Dense:
simsimd_dot_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results);
simsimd_dot_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results);
simsimd_dot_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_dot_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_dot_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_dot_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);

simsimd_dot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_dot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_dot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_dot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);
simsimd_vdot_f16c((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_vdot_bf16c((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_vdot_f32c((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_vdot_f64c((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);

simsimd_cos_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results);
simsimd_cos_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results);
simsimd_cos_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_cos_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_cos_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_cos_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);

simsimd_l2sq_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results);
simsimd_l2sq_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results);
simsimd_l2sq_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_l2sq_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_l2sq_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_l2sq_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);

simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results);
simsimd_l2_i8((simsimd_i8_t *)dummy, (simsimd_i8_t *)dummy, 0, dummy_results);
simsimd_l2_u8((simsimd_u8_t *)dummy, (simsimd_u8_t *)dummy, 0, dummy_results);
simsimd_l2_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_l2_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_l2_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_l2_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);

simsimd_hamming_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results);
simsimd_jaccard_b8((simsimd_b8_t *)dummy, (simsimd_b8_t *)dummy, 0, dummy_results);

simsimd_kl_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_kl_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_kl_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_kl_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);
simsimd_js_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_js_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_js_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_js_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);
simsimd_dot_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results);
simsimd_dot_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results);
simsimd_dot_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_dot_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_dot_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_dot_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

simsimd_dot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_dot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_dot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_dot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);
simsimd_vdot_f16c((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_vdot_bf16c((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_vdot_f32c((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_vdot_f64c((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

simsimd_cos_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results);
simsimd_cos_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results);
simsimd_cos_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_cos_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_cos_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_cos_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

simsimd_l2sq_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results);
simsimd_l2sq_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results);
simsimd_l2sq_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_l2sq_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_l2sq_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_l2sq_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

simsimd_l2_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results);
simsimd_l2_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, dummy_results);
simsimd_l2_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, dummy_results);
simsimd_l2_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_l2_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_l2_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_l2_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

simsimd_hamming_b8((simsimd_b8_t *)x, (simsimd_b8_t *)x, 0, dummy_results);
simsimd_jaccard_b8((simsimd_b8_t *)x, (simsimd_b8_t *)x, 0, dummy_results);

simsimd_kl_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_kl_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_kl_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_kl_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);
simsimd_js_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_js_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_js_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_js_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);

// Sparse
simsimd_intersect_u16((simsimd_u16_t *)dummy, (simsimd_u16_t *)dummy, 0, 0, dummy_results);
simsimd_intersect_u32((simsimd_u32_t *)dummy, (simsimd_u32_t *)dummy, 0, 0, dummy_results);
simsimd_intersect_u16((simsimd_u16_t *)x, (simsimd_u16_t *)x, 0, 0, dummy_results);
simsimd_intersect_u32((simsimd_u32_t *)x, (simsimd_u32_t *)x, 0, 0, dummy_results);

// Curved:
simsimd_bilinear_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);
simsimd_mahalanobis_f64((simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, (simsimd_f64_t *)dummy, 0, dummy_results);
simsimd_bilinear_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_mahalanobis_f32((simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, (simsimd_f32_t *)dummy, 0, dummy_results);
simsimd_bilinear_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_mahalanobis_f16((simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, (simsimd_f16_t *)dummy, 0, dummy_results);
simsimd_bilinear_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0, dummy_results);
simsimd_mahalanobis_bf16((simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, (simsimd_bf16_t *)dummy, 0,
dummy_results);
simsimd_bilinear_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);
simsimd_mahalanobis_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, dummy_results);
simsimd_bilinear_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_mahalanobis_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, dummy_results);
simsimd_bilinear_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_mahalanobis_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, dummy_results);
simsimd_bilinear_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);
simsimd_mahalanobis_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, dummy_results);

// Elementwise
simsimd_wsum_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, 0, 0, (simsimd_f64_t *)x);
simsimd_wsum_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, 0, 0, (simsimd_f32_t *)x);
simsimd_wsum_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, 0, 0, (simsimd_f16_t *)x);
simsimd_wsum_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, 0, 0, (simsimd_bf16_t *)x);
simsimd_wsum_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, 0, 0, (simsimd_i8_t *)x);
simsimd_wsum_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, 0, 0, (simsimd_u8_t *)x);
simsimd_fma_f64((simsimd_f64_t *)x, (simsimd_f64_t *)x, (simsimd_f64_t *)x, 0, 0, 0, (simsimd_f64_t *)x);
simsimd_fma_f32((simsimd_f32_t *)x, (simsimd_f32_t *)x, (simsimd_f32_t *)x, 0, 0, 0, (simsimd_f32_t *)x);
simsimd_fma_f16((simsimd_f16_t *)x, (simsimd_f16_t *)x, (simsimd_f16_t *)x, 0, 0, 0, (simsimd_f16_t *)x);
simsimd_fma_bf16((simsimd_bf16_t *)x, (simsimd_bf16_t *)x, (simsimd_bf16_t *)x, 0, 0, 0, (simsimd_bf16_t *)x);
simsimd_fma_i8((simsimd_i8_t *)x, (simsimd_i8_t *)x, (simsimd_i8_t *)x, 0, 0, 0, (simsimd_i8_t *)x);
simsimd_fma_u8((simsimd_u8_t *)x, (simsimd_u8_t *)x, (simsimd_u8_t *)x, 0, 0, 0, (simsimd_u8_t *)x);

return static_capabilities;
}
Expand Down
Loading

0 comments on commit efd2a52

Please sign in to comment.