Skip to content

Commit

Permalink
feat: SIMD version of quantizing (#574)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Aug 27, 2024
1 parent 165b7de commit b7e1a7a
Show file tree
Hide file tree
Showing 15 changed files with 366 additions and 51 deletions.
4 changes: 4 additions & 0 deletions crates/base/src/aligned.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/// Transparent-style wrapper that forces its contents to a 16-byte alignment.
///
/// The wrapped value is exposed through the public tuple field `0`; the
/// `repr(C, align(16))` attribute fixes the layout and raises the alignment.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(16))]
pub struct Aligned16<T>(pub T);

/// Transparent-style wrapper that forces its contents to a 32-byte alignment.
///
/// The wrapped value is exposed through the public tuple field `0`; the
/// `repr(C, align(32))` attribute fixes the layout and raises the alignment.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(32))]
pub struct Aligned32<T>(pub T);
34 changes: 34 additions & 0 deletions crates/base/src/scalar/emulate.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::missing_safety_doc)]

// VP2INTERSECT emulation.
// Díez-Cañas, G. (2021). Faster-Than-Native Alternatives for x86 VP2INTERSECT
// Instructions. arXiv preprint arXiv:2112.06342.
Expand Down Expand Up @@ -70,3 +72,35 @@ pub unsafe fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f
_mm256_cvtss_f32(x)
}
}

/// Horizontally reduces the eight `f32` lanes of `x` to their minimum.
///
/// The upper and lower 128-bit halves are combined lane-wise with
/// `_mm_min_ps`; the four surviving lanes are spilled to a 16-byte-aligned
/// buffer and folded pairwise with `f32::min`.
///
/// # Safety
///
/// The intrinsics used here need AVX. NOTE(review): presumably the caller must
/// guarantee that the CPU supports everything `target_cpu(enable = "v3")`
/// turns on — confirm against the `detect` crate.
#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
pub unsafe fn emulate_mm256_reduce_min_ps(x: std::arch::x86_64::__m256) -> f32 {
    use crate::aligned::Aligned16;
    unsafe {
        use std::arch::x86_64::*;
        // 8 lanes -> 4 lanes.
        let upper = _mm256_extractf128_ps(x, 1);
        let lower = _mm256_castps256_ps128(x);
        let folded = _mm_min_ps(lower, upper);
        // `_mm_store_ps` requires a 16-byte-aligned destination.
        let mut spill = Aligned16([0.0f32; 4]);
        _mm_store_ps(spill.0.as_mut_ptr(), folded);
        // 4 lanes -> scalar, pairing (0, 1) and (2, 3).
        let [a, b, c, d] = spill.0;
        let left = f32::min(a, b);
        let right = f32::min(c, d);
        f32::min(left, right)
    }
}

/// Horizontally reduces the eight `f32` lanes of `x` to their maximum.
///
/// The upper and lower 128-bit halves are combined lane-wise with
/// `_mm_max_ps`; the four surviving lanes are spilled to a 16-byte-aligned
/// buffer and folded pairwise with `f32::max`.
///
/// # Safety
///
/// The intrinsics used here need AVX. NOTE(review): presumably the caller must
/// guarantee that the CPU supports everything `target_cpu(enable = "v3")`
/// turns on — confirm against the `detect` crate.
#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
pub unsafe fn emulate_mm256_reduce_max_ps(x: std::arch::x86_64::__m256) -> f32 {
    use crate::aligned::Aligned16;
    unsafe {
        use std::arch::x86_64::*;
        // 8 lanes -> 4 lanes.
        let upper = _mm256_extractf128_ps(x, 1);
        let lower = _mm256_castps256_ps128(x);
        let folded = _mm_max_ps(lower, upper);
        // `_mm_store_ps` requires a 16-byte-aligned destination.
        let mut spill = Aligned16([0.0f32; 4]);
        _mm_store_ps(spill.0.as_mut_ptr(), folded);
        // 4 lanes -> scalar, pairing (0, 1) and (2, 3).
        let [a, b, c, d] = spill.0;
        let left = f32::max(a, b);
        let right = f32::max(c, d);
        f32::max(left, right)
    }
}
14 changes: 7 additions & 7 deletions crates/base/src/scalar/f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ impl ScalarLike for f16 {
}

// FIXME: add manually-implemented SIMD version
#[inline(always)]
#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn reduce_min_max_of_x(this: &[f16]) -> (f32, f32) {
let mut min = 0.0f32;
let mut max = 0.0f32;
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
let n = this.len();
for i in 0..n {
min = min.min(this[i].to_f32());
Expand Down Expand Up @@ -160,13 +160,13 @@ impl ScalarLike for f16 {
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn vector_div_scalar(lhs: &[f16], rhs: f32) -> Vec<f16> {
fn vector_mul_scalar(lhs: &[f16], rhs: f32) -> Vec<f16> {
let rhs = f16::from_f32(rhs);
let n = lhs.len();
let mut r = Vec::<f16>::with_capacity(n);
for i in 0..n {
unsafe {
r.as_mut_ptr().add(i).write(lhs[i] / rhs);
r.as_mut_ptr().add(i).write(lhs[i] * rhs);
}
}
unsafe {
Expand All @@ -176,11 +176,11 @@ impl ScalarLike for f16 {
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn vector_div_scalar_inplace(lhs: &mut [f16], rhs: f32) {
fn vector_mul_scalar_inplace(lhs: &mut [f16], rhs: f32) {
let rhs = f16::from_f32(rhs);
let n = lhs.len();
for i in 0..n {
lhs[i] /= rhs;
lhs[i] *= rhs;
}
}

Expand Down
175 changes: 154 additions & 21 deletions crates/base/src/scalar/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,9 @@ impl ScalarLike for f32 {
x2
}

// FIXME: add manually-implemented SIMD version
#[inline(always)]
fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) {
let mut min = 0.0f32;
let mut max = 0.0f32;
let n = this.len();
for i in 0..n {
min = min.min(this[i]);
max = max.max(this[i]);
}
(min, max)
reduce_min_max_of_x::reduce_min_max_of_x(this)
}

#[inline(always)]
Expand Down Expand Up @@ -159,12 +151,12 @@ impl ScalarLike for f32 {
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn vector_div_scalar(lhs: &[f32], rhs: f32) -> Vec<f32> {
fn vector_mul_scalar(lhs: &[f32], rhs: f32) -> Vec<f32> {
let n = lhs.len();
let mut r = Vec::<f32>::with_capacity(n);
for i in 0..n {
unsafe {
r.as_mut_ptr().add(i).write(lhs[i] / rhs);
r.as_mut_ptr().add(i).write(lhs[i] * rhs);
}
}
unsafe {
Expand All @@ -174,10 +166,10 @@ impl ScalarLike for f32 {
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn vector_div_scalar_inplace(lhs: &mut [f32], rhs: f32) {
fn vector_mul_scalar_inplace(lhs: &mut [f32], rhs: f32) {
let n = lhs.len();
for i in 0..n {
lhs[i] /= rhs;
lhs[i] *= rhs;
}
}

Expand All @@ -204,6 +196,147 @@ impl ScalarLike for f32 {
}
}

mod reduce_min_max_of_x {
// The semantics of `f32::min` differ from those of `_mm256_min_ps`
// (for example, NaN operands are handled differently), which may lead to issues...

/// AVX-512 implementation: returns `(min, max)` over all elements of `this`.
///
/// Full groups of 16 floats are processed with unmasked loads; a final
/// partial group (1..=15 floats) is handled with a masked load so that no
/// memory past the end of the slice is read. An empty slice yields
/// `(f32::INFINITY, f32::NEG_INFINITY)`.
///
/// # Safety
///
/// NOTE(review): presumably the caller must guarantee that the CPU supports
/// everything `target_cpu(enable = "v4")` turns on — confirm against the
/// `detect` crate.
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
unsafe fn reduce_min_max_of_x_v4(this: &[f32]) -> (f32, f32) {
    unsafe {
        use std::arch::x86_64::*;
        let mut cursor = this.as_ptr();
        let mut remaining = this.len();
        // Identity elements for min / max respectively.
        let mut lanes_min = _mm512_set1_ps(f32::INFINITY);
        let mut lanes_max = _mm512_set1_ps(f32::NEG_INFINITY);
        while remaining >= 16 {
            let chunk = _mm512_loadu_ps(cursor);
            cursor = cursor.add(16);
            remaining -= 16;
            lanes_min = _mm512_min_ps(chunk, lanes_min);
            lanes_max = _mm512_max_ps(chunk, lanes_max);
        }
        if remaining > 0 {
            // Low `remaining` bits set: only those lanes are loaded and updated,
            // the other lanes keep their running accumulator value.
            let tail_mask = _bzhi_u32(0xffff, remaining as u32) as u16;
            let tail = _mm512_maskz_loadu_ps(tail_mask, cursor);
            lanes_min = _mm512_mask_min_ps(lanes_min, tail_mask, tail, lanes_min);
            lanes_max = _mm512_mask_max_ps(lanes_max, tail_mask, tail, lanes_max);
        }
        let lowest = _mm512_reduce_min_ps(lanes_min);
        let highest = _mm512_reduce_max_ps(lanes_max);
        (lowest, highest)
    }
}

// Cross-checks the AVX-512 kernel against the fallback implementation on
// random inputs of every length in 50..200; skipped when v4 is not detected.
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn reduce_min_max_of_x_v4_test() {
    const EPSILON: f32 = 0.0001;
    detect::init();
    if !detect::v4::detect() {
        println!("test {} ... skipped (v4)", module_path!());
        return;
    }
    for _ in 0..300 {
        let samples: Vec<f32> = (0..200).map(|_| rand::random::<f32>()).collect();
        for len in 50..200 {
            let prefix = &samples[..len];
            let (got_min, got_max) = unsafe { reduce_min_max_of_x_v4(prefix) };
            let (want_min, want_max) = unsafe { reduce_min_max_of_x_fallback(prefix) };
            assert!(
                (got_min - want_min).abs() < EPSILON,
                "min: specialized = {}, fallback = {}.",
                got_min,
                want_min,
            );
            assert!(
                (got_max - want_max).abs() < EPSILON,
                "max: specialized = {}, fallback = {}.",
                got_max,
                want_max,
            );
        }
    }
}

/// AVX implementation: returns `(min, max)` over all elements of `this`.
///
/// Full groups of 8 floats are accumulated in 256-bit registers, reduced to
/// scalars, and the remaining 0..=7 elements are then folded in one by one
/// with `f32::min` / `f32::max`. An empty slice yields
/// `(f32::INFINITY, f32::NEG_INFINITY)`.
///
/// # Safety
///
/// NOTE(review): presumably the caller must guarantee that the CPU supports
/// everything `target_cpu(enable = "v3")` turns on — confirm against the
/// `detect` crate.
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
unsafe fn reduce_min_max_of_x_v3(this: &[f32]) -> (f32, f32) {
    use crate::scalar::emulate::{emulate_mm256_reduce_max_ps, emulate_mm256_reduce_min_ps};
    unsafe {
        use std::arch::x86_64::*;
        let mut cursor = this.as_ptr();
        let mut remaining = this.len();
        // Identity elements for min / max respectively.
        let mut lanes_min = _mm256_set1_ps(f32::INFINITY);
        let mut lanes_max = _mm256_set1_ps(f32::NEG_INFINITY);
        while remaining >= 8 {
            let chunk = _mm256_loadu_ps(cursor);
            cursor = cursor.add(8);
            remaining -= 8;
            lanes_min = _mm256_min_ps(chunk, lanes_min);
            lanes_max = _mm256_max_ps(chunk, lanes_max);
        }
        // Collapse the vector accumulators before the scalar tail.
        let mut lowest = emulate_mm256_reduce_min_ps(lanes_min);
        let mut highest = emulate_mm256_reduce_max_ps(lanes_max);
        // this hint is used to disable loop unrolling
        while std::hint::black_box(remaining) > 0 {
            let value = cursor.read();
            cursor = cursor.add(1);
            remaining -= 1;
            lowest = value.min(lowest);
            highest = value.max(highest);
        }
        (lowest, highest)
    }
}

// Cross-checks the AVX kernel against the fallback implementation on random
// inputs of every length in 50..200; skipped when v3 is not detected.
//
// The failure messages are prefixed with `min:` / `max:` (matching the v4
// test) so that a failing run tells which of the two results diverged.
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn reduce_min_max_of_x_v3_test() {
    const EPSILON: f32 = 0.0001;
    detect::init();
    if !detect::v3::detect() {
        println!("test {} ... skipped (v3)", module_path!());
        return;
    }
    for _ in 0..300 {
        let n = 200;
        let x = (0..n).map(|_| rand::random::<_>()).collect::<Vec<_>>();
        for z in 50..200 {
            let x = &x[..z];
            let specialized = unsafe { reduce_min_max_of_x_v3(x) };
            let fallback = unsafe { reduce_min_max_of_x_fallback(x) };
            assert!(
                (specialized.0 - fallback.0).abs() < EPSILON,
                "min: specialized = {}, fallback = {}.",
                specialized.0,
                fallback.0,
            );
            assert!(
                (specialized.1 - fallback.1).abs() < EPSILON,
                "max: specialized = {}, fallback = {}.",
                specialized.1,
                fallback.1,
            );
        }
    }
}

/// Returns `(min, max)` over all elements of `this`.
///
/// Starts from `(f32::INFINITY, f32::NEG_INFINITY)`, which is therefore the
/// result for an empty slice, and folds each element in with `f32::min` /
/// `f32::max` in slice order.
///
/// NOTE(review): the `multiversion` attribute presumably dispatches to the
/// hand-written `_v4` / `_v3` kernels and exposes this body as
/// `reduce_min_max_of_x_fallback` — confirm against the `detect` crate.
#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn reduce_min_max_of_x(this: &[f32]) -> (f32, f32) {
    let mut lowest = f32::INFINITY;
    let mut highest = f32::NEG_INFINITY;
    for &value in this {
        lowest = lowest.min(value);
        highest = highest.max(value);
    }
    (lowest, highest)
}
}

mod reduce_sum_of_xy {
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
Expand All @@ -216,17 +349,17 @@ mod reduce_sum_of_xy {
let mut b = rhs.as_ptr();
let mut xy = _mm512_setzero_ps();
while n >= 16 {
let x = _mm512_loadu_ps(a.cast());
let y = _mm512_loadu_ps(b.cast());
let x = _mm512_loadu_ps(a);
let y = _mm512_loadu_ps(b);
a = a.add(16);
b = b.add(16);
n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
}
if n > 0 {
let mask = _bzhi_u32(0xffff, n as u32) as u16;
let x = _mm512_maskz_loadu_ps(mask, a.cast());
let y = _mm512_maskz_loadu_ps(mask, b.cast());
let x = _mm512_maskz_loadu_ps(mask, a);
let y = _mm512_maskz_loadu_ps(mask, b);
xy = _mm512_fmadd_ps(x, y, xy);
}
_mm512_reduce_add_ps(xy)
Expand Down Expand Up @@ -271,16 +404,16 @@ mod reduce_sum_of_xy {
let mut b = rhs.as_ptr();
let mut xy = _mm256_setzero_ps();
while n >= 8 {
let x = _mm256_loadu_ps(a.cast());
let y = _mm256_loadu_ps(b.cast());
let x = _mm256_loadu_ps(a);
let y = _mm256_loadu_ps(b);
a = a.add(8);
b = b.add(8);
n -= 8;
xy = _mm256_fmadd_ps(x, y, xy);
}
if n >= 4 {
let x = _mm256_zextps128_ps256(_mm_loadu_ps(a.cast()));
let y = _mm256_zextps128_ps256(_mm_loadu_ps(b.cast()));
let x = _mm256_zextps128_ps256(_mm_loadu_ps(a));
let y = _mm256_zextps128_ps256(_mm_loadu_ps(b));
a = a.add(4);
b = b.add(4);
n -= 4;
Expand Down
4 changes: 2 additions & 2 deletions crates/base/src/scalar/impossible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ impl ScalarLike for Impossible {
unimplemented!()
}

fn vector_div_scalar(_lhs: &[Self], _rhs: f32) -> Vec<Self> {
fn vector_mul_scalar(_lhs: &[Self], _rhs: f32) -> Vec<Self> {
unimplemented!()
}

fn vector_div_scalar_inplace(_lhs: &mut [Self], _rhs: f32) {
fn vector_mul_scalar_inplace(_lhs: &mut [Self], _rhs: f32) {
unimplemented!()
}

Expand Down
6 changes: 3 additions & 3 deletions crates/base/src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod bit;
mod emulate;
pub mod emulate;
mod f16;
mod f32;
pub mod impossible;
Expand Down Expand Up @@ -44,8 +44,8 @@ pub trait ScalarLike:
fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]);
fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;
fn vector_mul(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;
fn vector_div_scalar(lhs: &[Self], rhs: f32) -> Vec<Self>;
fn vector_div_scalar_inplace(lhs: &mut [Self], rhs: f32);
fn vector_mul_scalar(lhs: &[Self], rhs: f32) -> Vec<Self>;
fn vector_mul_scalar_inplace(lhs: &mut [Self], rhs: f32);

fn kmeans_helper(this: &mut [Self], x: f32, y: f32);
}
2 changes: 1 addition & 1 deletion crates/base/src/vector/svect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl<'a, S: ScalarLike> VectorBorrowed for SVectBorrowed<'a, S> {
let l = S::reduce_sum_of_x2(self.values).sqrt();
let indexes = self.indexes.to_vec();
let mut values = self.values.to_vec();
S::vector_div_scalar_inplace(&mut values, l);
S::vector_mul_scalar_inplace(&mut values, 1.0 / l);
// FIXME: it may panic because of zeros
SVectOwned::new(self.dims, indexes, values)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/vector/vect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl<'a, S: ScalarLike> VectorBorrowed for VectBorrowed<'a, S> {
fn function_normalize(&self) -> VectOwned<S> {
let mut data = self.0.to_vec();
let l = S::reduce_sum_of_x2(&data).sqrt();
S::vector_div_scalar_inplace(&mut data, l);
S::vector_mul_scalar_inplace(&mut data, 1.0 / l);
VectOwned(data)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/k_means/src/elkan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl<S: ScalarLike, F: FnMut(&mut [S])> ElkanKMeans<S, F> {
if count[i] == 0.0f32 {
continue;
}
S::vector_div_scalar_inplace(&mut centroids[(i,)], count[i]);
S::vector_mul_scalar_inplace(&mut centroids[(i,)], 1.0 / count[i]);
}
for i in 0..c {
if count[i] != 0.0f32 {
Expand Down
2 changes: 1 addition & 1 deletion crates/k_means/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub fn k_means_lookup_many<S: ScalarLike>(vector: &[S], centroids: &Vec2<S>) ->

fn spherical<S: ScalarLike>(vector: &mut [S]) {
let l = S::reduce_sum_of_x2(vector).sqrt();
S::vector_div_scalar_inplace(vector, l);
S::vector_mul_scalar_inplace(vector, 1.0 / l);
}

/// No-op: accepts a mutable slice and leaves it untouched.
fn dummy<S: ScalarLike>(_: &mut [S]) {}
Loading

0 comments on commit b7e1a7a

Please sign in to comment.