diff --git a/include/simsimd/binary.h b/include/simsimd/binary.h index d83e3aca..2f6a8059 100644 --- a/include/simsimd/binary.h +++ b/include/simsimd/binary.h @@ -160,30 +160,72 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const* a, simsimd_b8_t SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* result) { - simsimd_size_t i = 0; + + // On very small register sizes, NEON is at least as fast as SVE. + simsimd_size_t const words_per_register = svcntb(); + if (words_per_register <= 32) { + simsimd_hamming_b8_neon(a, b, n_words, result); + return; + } + + // On larger register sizes, SVE is faster. + simsimd_size_t i = 0, cycle = 0; simsimd_i32_t differences = 0; - do { - svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); - svuint8_t a_vec = svld1_u8(pg_vec, a + i); - svuint8_t b_vec = svld1_u8(pg_vec, b + i); - differences += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), sveor_u8_m(svptrue_b8(), a_vec, b_vec))); - i += svcntb(); - } while (i < n_words); + svuint8_t differences_cycle_vec = svdup_n_u8(0); + svbool_t const all_vec = svptrue_b8(); + while (i < n_words) { + do { + svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); + svuint8_t a_vec = svld1_u8(pg_vec, a + i); + svuint8_t b_vec = svld1_u8(pg_vec, b + i); + differences_cycle_vec = + svadd_u8_z(all_vec, differences_cycle_vec, svcnt_u8_x(all_vec, sveor_u8_m(all_vec, a_vec, b_vec))); + i += words_per_register; + ++cycle; + } while (i < n_words && cycle < 31); + differences += svaddv_u8(all_vec, differences_cycle_vec); + differences_cycle_vec = svdup_n_u8(0); + cycle = 0; // Reset the cycle counter. + } + *result = differences; } SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* result) { - simsimd_size_t i = 0; + + // On very small register sizes, NEON is at least as fast as SVE. + simsimd_size_t const words_per_register = svcntb(); + if (words_per_register <= 32) { + simsimd_jaccard_b8_neon(a, b, n_words, result); + return; + } + + // On larger register sizes, SVE is faster. + simsimd_size_t i = 0, cycle = 0; simsimd_i32_t intersection = 0, union_ = 0; - do { - svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); - svuint8_t a_vec = svld1_u8(pg_vec, a + i); - svuint8_t b_vec = svld1_u8(pg_vec, b + i); - intersection += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), svand_u8_m(svptrue_b8(), a_vec, b_vec))); - union_ += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), svorr_u8_m(svptrue_b8(), a_vec, b_vec))); - i += svcntb(); - } while (i < n_words); + svuint8_t intersection_cycle_vec = svdup_n_u8(0); + svuint8_t union_cycle_vec = svdup_n_u8(0); + svbool_t const all_vec = svptrue_b8(); + while (i < n_words) { + do { + svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); + svuint8_t a_vec = svld1_u8(pg_vec, a + i); + svuint8_t b_vec = svld1_u8(pg_vec, b + i); + intersection_cycle_vec = + svadd_u8_z(all_vec, intersection_cycle_vec, svcnt_u8_x(all_vec, svand_u8_m(all_vec, a_vec, b_vec))); + union_cycle_vec = + svadd_u8_z(all_vec, union_cycle_vec, svcnt_u8_x(all_vec, svorr_u8_m(all_vec, a_vec, b_vec))); + i += words_per_register; + ++cycle; + } while (i < n_words && cycle < 31); + intersection += svaddv_u8(all_vec, intersection_cycle_vec); + intersection_cycle_vec = svdup_n_u8(0); + union_ += svaddv_u8(all_vec, union_cycle_vec); + union_cycle_vec = svdup_n_u8(0); + cycle = 0; // Reset the cycle counter. + } + *result = (union_ != 0) ? 1 - (simsimd_f64_t)intersection / (simsimd_f64_t)union_ : 1; }