Skip to content

Commit

Permalink
Improve: Type-casting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Nov 5, 2024
1 parent a334e99 commit 14fd5d3
Show file tree
Hide file tree
Showing 9 changed files with 589 additions and 519 deletions.
93 changes: 57 additions & 36 deletions include/simsimd/curved.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
SIMSIMD_PUBLIC void simsimd_bilinear_##input_type##_##name( \
simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
simsimd_size_t n, simsimd_distance_t *result) { \
simsimd_##accumulator_type##_t sum = 0; \
simsimd_##accumulator_type##_t sum = 0, a_i, b_j, c_ij; \
for (simsimd_size_t i = 0; i != n; ++i) { \
simsimd_##accumulator_type##_t partial = 0; \
simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \
load_and_convert(a + i, &a_i); \
for (simsimd_size_t j = 0; j != n; ++j) { \
simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \
simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
load_and_convert(b + j, &b_j); \
load_and_convert(c + i * n + j, &c_ij); \
partial += c_ij * b_j; \
} \
sum += a_i * partial; \
Expand All @@ -114,40 +114,44 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
SIMSIMD_PUBLIC void simsimd_mahalanobis_##input_type##_##name( \
simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
simsimd_size_t n, simsimd_distance_t *result) { \
simsimd_##accumulator_type##_t sum = 0; \
simsimd_##accumulator_type##_t sum = 0, a_i, a_j, b_i, b_j, c_ij; \
for (simsimd_size_t i = 0; i != n; ++i) { \
simsimd_##accumulator_type##_t partial = 0; \
simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \
load_and_convert(a + i, &a_i); \
load_and_convert(b + i, &b_i); \
simsimd_##accumulator_type##_t diff_i = a_i - b_i; \
for (simsimd_size_t j = 0; j != n; ++j) { \
simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \
simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
load_and_convert(a + j, &a_j); \
load_and_convert(b + j, &b_j); \
simsimd_##accumulator_type##_t diff_j = a_j - b_j; \
load_and_convert(c + i * n + j, &c_ij); \
partial += c_ij * diff_j; \
} \
sum += diff_i * partial; \
} \
*result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \
}

SIMSIMD_MAKE_BILINEAR(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f64_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f64_serial
SIMSIMD_MAKE_BILINEAR(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f64_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f64_serial

SIMSIMD_MAKE_BILINEAR(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f32_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f32_serial
SIMSIMD_MAKE_BILINEAR(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f32_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f32_serial

SIMSIMD_MAKE_BILINEAR(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_bilinear_f16_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_mahalanobis_f16_serial
SIMSIMD_MAKE_BILINEAR(serial, f16, f32, simsimd_f16_to_f32) // simsimd_bilinear_f16_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, simsimd_f16_to_f32) // simsimd_mahalanobis_f16_serial

SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_bilinear_bf16_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_mahalanobis_bf16_serial
SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, simsimd_bf16_to_f32) // simsimd_bilinear_bf16_serial
SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, simsimd_bf16_to_f32) // simsimd_mahalanobis_bf16_serial

SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f32_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f32_accurate
SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f32_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f32_accurate

SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_bilinear_f16_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_mahalanobis_f16_accurate
SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, _simsimd_f16_to_f64) // simsimd_bilinear_f16_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, _simsimd_f16_to_f64) // simsimd_mahalanobis_f16_accurate

SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_bilinear_bf16_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_mahalanobis_bf16_accurate
SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, _simsimd_bf16_to_f64) // simsimd_bilinear_bf16_accurate
SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, _simsimd_bf16_to_f64) // simsimd_mahalanobis_bf16_accurate

#if _SIMSIMD_TARGET_ARM
#if SIMSIMD_TARGET_NEON
Expand Down Expand Up @@ -313,7 +317,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) {
float32x4_t sum_vec = vdupq_n_f32(0);
for (simsimd_size_t i = 0; i != n; ++i) {
float32x4_t a_vec = vdupq_n_f32(simsimd_bf16_to_f32(a + i));
simsimd_f32_t a_i;
simsimd_bf16_to_f32(a + i, &a_i);
float32x4_t a_vec = vdupq_n_f32(a_i);
float32x4_t partial_sum_vec = vdupq_n_f32(0);
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j));
Expand All @@ -329,7 +335,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
simsimd_size_t tail_start = n - tail_length;
if (tail_length) {
for (simsimd_size_t i = 0; i != n; ++i) {
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
simsimd_f32_t a_i;
simsimd_bf16_to_f32(a + i, &a_i);
bfloat16x8_t b_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
bfloat16x8_t c_vec = _simsimd_partial_load_bf16x8_neon(c + i * n + tail_start, tail_length);
simsimd_f32_t partial_sum = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), b_vec, c_vec));
Expand All @@ -345,8 +352,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
simsimd_distance_t *result) {
float32x4_t sum_vec = vdupq_n_f32(0);
for (simsimd_size_t i = 0; i != n; ++i) {
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
simsimd_f32_t a_i, b_i;
simsimd_bf16_to_f32(a + i, &a_i);
simsimd_bf16_to_f32(b + i, &b_i);
float32x4_t diff_i_vec = vdupq_n_f32(a_i - b_i);
float32x4_t partial_sum_vec = vdupq_n_f32(0);
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
Expand Down Expand Up @@ -376,8 +384,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
simsimd_size_t tail_start = n - tail_length;
if (tail_length) {
for (simsimd_size_t i = 0; i != n; ++i) {
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
simsimd_f32_t a_i, b_i;
simsimd_bf16_to_f32(a + i, &a_i);
simsimd_bf16_to_f32(b + i, &b_i);
simsimd_f32_t diff_i = a_i - b_i;
bfloat16x8_t a_j_vec = _simsimd_partial_load_bf16x8_neon(a + tail_start, tail_length);
bfloat16x8_t b_j_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
Expand Down Expand Up @@ -489,7 +498,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
__m256 sum_vec = _mm256_setzero_ps();
for (simsimd_size_t i = 0; i != n; ++i) {
// The `simsimd_bf16_to_f32` is cheaper than `_simsimd_bf16x8_to_f32x8_haswell`
__m256 a_vec = _mm256_set1_ps(simsimd_bf16_to_f32(a + i));
simsimd_f32_t a_i;
simsimd_bf16_to_f32(a + i, &a_i);
__m256 a_vec = _mm256_set1_ps(a_i);
__m256 partial_sum_vec = _mm256_setzero_ps();
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
__m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j)));
Expand All @@ -505,7 +516,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
simsimd_size_t tail_start = n - tail_length;
if (tail_length) {
for (simsimd_size_t i = 0; i != n; ++i) {
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
simsimd_f32_t a_i;
simsimd_bf16_to_f32(a + i, &a_i);
__m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell( //
_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length));
__m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell( //
Expand All @@ -523,9 +535,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
simsimd_distance_t *result) {
__m256 sum_vec = _mm256_setzero_ps();
for (simsimd_size_t i = 0; i != n; ++i) {
__m256 diff_i_vec = _mm256_sub_ps( //
_mm256_set1_ps(simsimd_bf16_to_f32(a + i)), //
_mm256_set1_ps(simsimd_bf16_to_f32(b + i)));
simsimd_f32_t a_i, b_i;
simsimd_bf16_to_f32(a + i, &a_i);
simsimd_bf16_to_f32(b + i, &b_i);
__m256 diff_i_vec = _mm256_set1_ps(a_i - b_i);
__m256 partial_sum_vec = _mm256_setzero_ps();
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
__m256 diff_j_vec = _mm256_sub_ps( //
Expand All @@ -543,7 +556,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
simsimd_size_t tail_start = n - tail_length;
if (tail_length) {
for (simsimd_size_t i = 0; i != n; ++i) {
simsimd_f32_t diff_i = simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i);
simsimd_f32_t a_i, b_i;
simsimd_bf16_to_f32(a + i, &a_i);
simsimd_bf16_to_f32(b + i, &b_i);
simsimd_f32_t diff_i = a_i - b_i;
__m256 diff_j_vec = _mm256_sub_ps( //
_simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(a + tail_start, tail_length)),
_simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length)));
Expand Down Expand Up @@ -651,7 +667,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);

for (simsimd_size_t i = 0; i != n; ++i) {
__m512 a_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i));
simsimd_f32_t a_i;
simsimd_bf16_to_f32(a + i, &a_i);
__m512 a_vec = _mm512_set1_ps(a_i);
__m512 partial_sum_vec = _mm512_setzero_ps();
__m512i b_vec, c_vec;
simsimd_size_t j = 0;
Expand Down Expand Up @@ -683,7 +701,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);

for (simsimd_size_t i = 0; i != n; ++i) {
__m512 diff_i_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i));
simsimd_f32_t a_i, b_i;
simsimd_bf16_to_f32(a + i, &a_i);
simsimd_bf16_to_f32(b + i, &b_i);
__m512 diff_i_vec = _mm512_set1_ps(a_i - b_i);
__m512 partial_sum_vec = _mm512_setzero_ps();
__m512i a_j_vec, b_j_vec, diff_j_vec, c_vec;
simsimd_size_t j = 0;
Expand Down
Loading

0 comments on commit 14fd5d3

Please sign in to comment.