Skip to content

Commit

Permalink
Merge pull request #16 from ashvardanian/main-dev
Browse files Browse the repository at this point in the history
x86 Optimizations
  • Loading branch information
ashvardanian authored Sep 5, 2023
2 parents 05ee5ee + ff2d333 commit a3239cc
Showing 1 changed file with 106 additions and 40 deletions.
146 changes: 106 additions & 40 deletions include/simsimd/simsimd.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
#endif
#endif

#undef SIMSIMD_TARGET_X86
#undef SIMSIMD_TARGET_X86_AVX2
#undef SIMSIMD_TARGET_X86_AVX512
#define SIMSIMD_TARGET_X86 0
#define SIMSIMD_TARGET_X86_AVX2 0
#define SIMSIMD_TARGET_X86_AVX512 0
Expand All @@ -71,6 +74,9 @@
#endif
#endif

#undef SIMSIMD_TARGET_ARM
#undef SIMSIMD_TARGET_ARM_NEON
#undef SIMSIMD_TARGET_ARM_SVE
#define SIMSIMD_TARGET_ARM 0
#define SIMSIMD_TARGET_ARM_NEON 0
#define SIMSIMD_TARGET_ARM_SVE 0
Expand Down Expand Up @@ -143,7 +149,7 @@ simsimd_cos_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, size_t d) {
simsimd_f32_t ab = svaddv_f32(svptrue_b32(), ab_vec);
simsimd_f32_t a2 = svaddv_f32(svptrue_b32(), a2_vec);
simsimd_f32_t b2 = svaddv_f32(svptrue_b32(), b2_vec);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("+sve"))) inline static simsimd_f32_t //
Expand Down Expand Up @@ -184,7 +190,7 @@ simsimd_cos_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, si
simsimd_f16_t ab = svaddv_f16(svptrue_b16(), ab_vec);
simsimd_f16_t a2 = svaddv_f16(svptrue_b16(), a2_vec);
simsimd_f16_t b2 = svaddv_f16(svptrue_b16(), b2_vec);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("+sve+fp16"))) inline static simsimd_f32_t //
Expand Down Expand Up @@ -235,6 +241,18 @@ simsimd_dot_f32x4_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, size_t d)
return 1 - vaddvq_f32(ab_vec);
}

__attribute__((target("+simd+fp16"))) inline static simsimd_f32_t //
simsimd_dot_f16x4_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d) {

    // Inner product of two half-precision vectors, four lanes per iteration:
    // each 4-lane f16 chunk is widened to f32 before the fused multiply-add.
    // Assumes `d` is a multiple of 4 — TODO confirm against callers.
    float16_t const* a_f16 = (float16_t const*)a;
    float16_t const* b_f16 = (float16_t const*)b;
    float32x4_t sum_vec = vdupq_n_f32(0);
    for (size_t i = 0; i != d; i += 4) {
        float32x4_t x_vec = vcvt_f32_f16(vld1_f16(a_f16 + i));
        float32x4_t y_vec = vcvt_f32_f16(vld1_f16(b_f16 + i));
        sum_vec = vfmaq_f32(sum_vec, x_vec, y_vec);
    }
    // Horizontal reduction; reported as a distance: 1 - <a, b>.
    return 1 - vaddvq_f32(sum_vec);
}

__attribute__((target("+simd+fp16"))) inline static simsimd_f32_t //
simsimd_cos_f16x4_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d) {

Expand All @@ -252,7 +270,7 @@ simsimd_cos_f16x4_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d)
simsimd_f32_t ab = vaddvq_f32(ab_vec);
simsimd_f32_t a2 = vaddvq_f32(a2_vec);
simsimd_f32_t b2 = vaddvq_f32(b2_vec);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("+simd"))) inline static simsimd_f32_t //
Expand Down Expand Up @@ -290,7 +308,7 @@ simsimd_cos_i8x16_neon(int8_t const* a, int8_t const* b, size_t d) {
int32_t a2 = vget_lane_s32(vpadd_s32(a2_part, a2_part), 0);
int32x2_t b2_part = vadd_s32(vget_high_s32(b2_vec), vget_low_s32(b2_vec));
int32_t b2 = vget_lane_s32(vpadd_s32(b2_part, b2_part), 0);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("+simd"))) inline static simsimd_f32_t //
Expand Down Expand Up @@ -327,7 +345,7 @@ simsimd_cos_f32x4_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, size_t d)
simsimd_f32_t ab = vaddvq_f32(ab_vec);
simsimd_f32_t a2 = vaddvq_f32(a2_vec);
simsimd_f32_t b2 = vaddvq_f32(b2_vec);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("+sve"))) inline static simsimd_f32_t //
Expand Down Expand Up @@ -385,7 +403,7 @@ simsimd_cos_f32x4_avx2(simsimd_f32_t const* a, simsimd_f32_t const* b, size_t d)
union simsimd_f32i32_t ab_union = {_mm_cvtsi128_si32(_mm_castps_si128(ab_vec))};
union simsimd_f32i32_t a2_union = {_mm_cvtsi128_si32(_mm_castps_si128(a2_vec))};
union simsimd_f32i32_t b2_union = {_mm_cvtsi128_si32(_mm_castps_si128(b2_vec))};
return 1 - ab_union.f / (sqrt(a2_union.f) * sqrt(b2_union.f));
return 1 - ab_union.f / (sqrtf(a2_union.f) * sqrtf(b2_union.f));
}

__attribute__((target("avx2"))) //
Expand All @@ -407,64 +425,112 @@ __attribute__((target("avx2"))) //
__attribute__((target("f16c"))) //
__attribute__((target("fma"))) inline static simsimd_f32_t //
simsimd_dot_f16x8_avx2(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d) {
__m128 ab_vec = _mm_set1_ps(0);
__m256 ab_vec = _mm256_set1_ps(0);
for (size_t i = 0; i != d; i += 8) {
__m128 a_vec = _mm_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
__m128 b_vec = _mm_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
ab_vec = _mm_fmadd_ps(a_vec, b_vec, ab_vec);
__m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
__m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec);
}
ab_vec = _mm_hadd_ps(ab_vec, ab_vec);
ab_vec = _mm_hadd_ps(ab_vec, ab_vec);
ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 0x81), ab_vec);
ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
float result[1];
_mm_store_ss(result, ab_vec);
_mm_store_ss(result, _mm256_castps256_ps128(ab_vec));
return 1 - result[0];
}

__attribute__((target("avx2"))) //
__attribute__((target("f16c"))) //
__attribute__((target("fma"))) inline static simsimd_f32_t //
simsimd_cos_f16x8_avx2(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d) {
    // Cosine distance between two half-precision vectors, eight lanes per
    // iteration. Accumulates <a,b>, <a,a>, and <b,b> in separate f32
    // registers, widening each 128-bit chunk of halfs via F16C.
    // Assumes `d` is a multiple of 8 — TODO confirm against callers.
    __m256 ab_vec = _mm256_set1_ps(0);
    __m256 a2_vec = _mm256_set1_ps(0);
    __m256 b2_vec = _mm256_set1_ps(0);
    for (size_t i = 0; i != d; i += 8) {
        __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
        __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
        ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec);
        a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec);
        b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec);
    }
    // Reduce each accumulator: fold high 128 bits onto the low ones (0x81
    // zeroes the upper lane), then two horizontal adds leave the sum in lane 0.
    ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 0x81), ab_vec);
    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
    a2_vec = _mm256_add_ps(_mm256_permute2f128_ps(a2_vec, a2_vec, 0x81), a2_vec);
    a2_vec = _mm256_hadd_ps(a2_vec, a2_vec);
    a2_vec = _mm256_hadd_ps(a2_vec, a2_vec);
    b2_vec = _mm256_add_ps(_mm256_permute2f128_ps(b2_vec, b2_vec, 0x81), b2_vec);
    b2_vec = _mm256_hadd_ps(b2_vec, b2_vec);
    b2_vec = _mm256_hadd_ps(b2_vec, b2_vec);
    float ab_result[1], a2_result[1], b2_result[1];
    _mm_store_ss(ab_result, _mm256_castps256_ps128(ab_vec));
    _mm_store_ss(a2_result, _mm256_castps256_ps128(a2_vec));
    _mm_store_ss(b2_result, _mm256_castps256_ps128(b2_vec));
    // NOTE(review): divides by zero (NaN/Inf) if either vector is all-zero,
    // matching the sibling kernels in this file.
    return 1 - ab_result[0] / (sqrtf(a2_result[0]) * sqrtf(b2_result[0]));
}

__attribute__((target("avx2"))) //
__attribute__((target("f16c"))) //
__attribute__((target("fma"))) inline static simsimd_f32_t //
simsimd_l2sq_f16x8_avx2(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t d) {
    // Squared Euclidean (L2^2) distance between two half-precision vectors,
    // eight lanes per iteration: widen each chunk to f32 (F16C), subtract,
    // and accumulate the squared difference with an FMA.
    // Assumes `d` is a multiple of 8 — TODO confirm against callers.
    __m256 sum_vec = _mm256_set1_ps(0);
    for (size_t i = 0; i != d; i += 8) {
        __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
        __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
        __m256 diff_vec = _mm256_sub_ps(a_vec, b_vec);
        sum_vec = _mm256_fmadd_ps(diff_vec, diff_vec, sum_vec);
    }
    // Reduce: fold the high 128 bits onto the low ones (0x81 zeroes the upper
    // lane), then two horizontal adds leave the total in lane 0.
    sum_vec = _mm256_add_ps(_mm256_permute2f128_ps(sum_vec, sum_vec, 0x81), sum_vec);
    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
    float result[1];
    _mm_store_ss(result, _mm256_castps256_ps128(sum_vec));
    return result[0];
}

__attribute__((target("avx2"))) //
inline static simsimd_f32_t
simsimd_cos_i8x32_avx2(int8_t const* a, int8_t const* b, size_t d) {
    // Cosine distance between two signed 8-bit vectors, 32 lanes per
    // iteration, with exact 32-bit integer accumulation.
    // Assumes `d` is a multiple of 32 — TODO confirm against callers.
    __m256i ab_vec = _mm256_setzero_si256();
    __m256i a2_vec = _mm256_setzero_si256();
    __m256i b2_vec = _mm256_setzero_si256();

    for (size_t i = 0; i != d; i += 32) {
        __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i));
        __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i));

        // BUGFIX: `_mm256_maddubs_epi16` treats its first operand as *unsigned*
        // bytes — wrong for signed int8 — and yields 16-bit pair sums, which
        // the old code then accumulated with 32-bit adds (lane-width mismatch).
        // Instead, sign-extend each 16-byte half to 16-bit lanes and use
        // `_mm256_madd_epi16`, which multiplies i16 pairs and sums them into
        // exact 32-bit integers.
        __m256i a_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(a_vec));
        __m256i a_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1));
        __m256i b_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(b_vec));
        __m256i b_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1));

        ab_vec = _mm256_add_epi32( //
            ab_vec, _mm256_add_epi32(_mm256_madd_epi16(a_lo, b_lo), _mm256_madd_epi16(a_hi, b_hi)));
        a2_vec = _mm256_add_epi32( //
            a2_vec, _mm256_add_epi32(_mm256_madd_epi16(a_lo, a_lo), _mm256_madd_epi16(a_hi, a_hi)));
        b2_vec = _mm256_add_epi32( //
            b2_vec, _mm256_add_epi32(_mm256_madd_epi16(b_lo, b_lo), _mm256_madd_epi16(b_hi, b_hi)));
    }

    // Horizontal sum across each 256-bit register: fold to 128 bits first.
    __m128i ab_sum = _mm_add_epi32(_mm256_castsi256_si128(ab_vec), _mm256_extracti128_si256(ab_vec, 1));
    __m128i a2_sum = _mm_add_epi32(_mm256_castsi256_si128(a2_vec), _mm256_extracti128_si256(a2_vec, 1));
    __m128i b2_sum = _mm_add_epi32(_mm256_castsi256_si128(b2_vec), _mm256_extracti128_si256(b2_vec, 1));

    // Two pairwise adds leave the full sum in lane 0 of each register.
    int ab = _mm_extract_epi32(_mm_hadd_epi32(_mm_hadd_epi32(ab_sum, ab_sum), ab_sum), 0);
    int a2 = _mm_extract_epi32(_mm_hadd_epi32(_mm_hadd_epi32(a2_sum, a2_sum), a2_sum), 0);
    int b2 = _mm_extract_epi32(_mm_hadd_epi32(_mm_hadd_epi32(b2_sum, b2_sum), b2_sum), 0);

    // Distance: 1 - cos(a, b). Integer sums are promoted to float by the
    // division against the float denominator. NOTE(review): divides by zero
    // (NaN/Inf) if either vector is all-zero, matching the sibling kernels.
    return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

#endif // SIMSIMD_TARGET_X86_AVX2

#if SIMSIMD_TARGET_X86_AVX512
Expand Down Expand Up @@ -498,7 +564,7 @@ simsimd_cos_f16x16_avx512(simsimd_f16_t const* a, simsimd_f16_t const* b, size_t
simsimd_f32_t ab = _mm512_reduce_add_ps(ab_vec);
simsimd_f32_t a2 = _mm512_reduce_add_ps(a2_vec);
simsimd_f32_t b2 = _mm512_reduce_add_ps(b2_vec);
return 1 - ab / (sqrt(a2) * sqrt(b2));
return 1 - ab / (sqrtf(a2) * sqrtf(b2));
}

__attribute__((target("avx512fp16"))) //
Expand Down

0 comments on commit a3239cc

Please sign in to comment.