Complex Bilinear Forms for Computational Physics #240

Merged 26 commits into main on Nov 26, 2024

@ashvardanian (Owner) commented on Nov 24, 2024

  • Technical highlight of the release: AVX-512 masking is extremely handy in implementing unrolled BLAS Level 2 operations for small inputs, resulting in up to 5x faster kernels than OpenBLAS.
  • Semantic highlight of the release: Bilinear forms now support complex numbers as inputs, extending the kernels' applicability to Computational Physics.

Bilinear Forms are essential in Scientific Computing. Some of the most computationally intensive cases arise in Quantum systems and their simulations, as discussed on r/Quantum. This PR adds support for complex inputs, making the kernels more broadly applicable.

$$\text{BilinearForm}(a, b, M) = a^T M b$$

In Python, you can compute this by calling two NumPy functions back to back, ideally reusing a buffer for the intermediate result:

```python
import numpy as np

ndim = 128
dtype = np.float32

first_quantum_state = np.random.randn(ndim).astype(dtype)
second_quantum_state = np.random.randn(ndim).astype(dtype)
interaction_matrix = np.random.randn(ndim, ndim).astype(dtype)

# Reuse a preallocated buffer for the intermediate vector-matrix product
temporary_vector = np.empty((ndim,), dtype=dtype)
np.matmul(first_quantum_state, interaction_matrix, out=temporary_vector)
result: float = np.inner(temporary_vector, second_quantum_state)
```

With SimSIMD, the last two lines are fused into a single call:

```python
import simsimd as simd

simd.bilinear(first_quantum_state, second_quantum_state, interaction_matrix)
```

For 128-dimensional np.float32 inputs, the latency drops from 2.11 μs with NumPy to 1.31 μs with SimSIMD. For smaller 16-dimensional np.float32 inputs, it drops from 1.31 μs to 202 ns. As always, the gap is wider for low-precision np.float16 representations: 2.68 μs with NumPy vs 313 ns with SimSIMD.

Small Matrices and AVX-512

In the past, developers dealing with small matrices used to ship a separate precompiled kernel for every reasonable matrix size. That inflates the binary and renders the CPU's L1i instruction cache ineffective. With AVX-512, however, we can reuse the same vectorized loop across matrix sizes, with just one additional BZHI instruction precomputing the load mask.
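Here is a minimal sketch of that pattern in C with AVX-512 intrinsics — an illustration of the masking idea only, with a hypothetical function name, not the exact kernel from this PR:

```c
#include <immintrin.h>
#include <stddef.h>

// One vectorized loop for any `n`: full 16-lane iterations first, then a
// single masked iteration for the tail, with the mask precomputed by BZHI.
float masked_dot_f32(float const *a, float const *b, size_t n) {
    __m512 sum = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 16 <= n; i += 16)
        sum = _mm512_fmadd_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i), sum);
    // `_bzhi_u32(~0u, k)` zeroes all bits at position `k` and above,
    // yielding a 16-bit load mask for the `k = n - i` remaining lanes.
    __mmask16 tail = (__mmask16)_bzhi_u32(~0u, (unsigned)(n - i));
    sum = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(tail, a + i),
                          _mm512_maskz_loadu_ps(tail, b + i), sum);
    return _mm512_reduce_add_ps(sum);
}
```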

Avoiding Data Dependency

A common approach in dot products is to accumulate the partial sums in a single register. The underlying VFMADD132PS instruction:

  • on AMD Zen 4 has a latency of 4 cycles and can execute on ports 0 and 1;
  • on Intel Skylake-X has a latency of 4 cycles and can execute on ports 0 and 5.

With a 4-cycle latency and 2 available ports, a chain of FMAs where each instruction depends on the previous result completes only one FMA every 4 cycles, leaving most of the throughput unused even on modern hardware. Future generations may execute it on more ports, so to "futureproof" the solution, I use 4 independent accumulators, as sketched below.
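A simplified sketch of that unrolling, assuming the length is a multiple of 64 floats (the real kernels also mask the tails, as above):

```c
#include <immintrin.h>
#include <stddef.h>

// Four independent accumulators break the dependency chain: consecutive
// FMAs in program order no longer wait on each other's results, hiding
// the 4-cycle instruction latency.
float dot_f32_unrolled(float const *a, float const *b, size_t n) {
    __m512 sum0 = _mm512_setzero_ps(), sum1 = _mm512_setzero_ps();
    __m512 sum2 = _mm512_setzero_ps(), sum3 = _mm512_setzero_ps();
    for (size_t i = 0; i < n; i += 64) {
        sum0 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i +  0), _mm512_loadu_ps(b + i +  0), sum0);
        sum1 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16), sum1);
        sum2 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32), sum2);
        sum3 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48), sum3);
    }
    __m512 sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3));
    return _mm512_reduce_add_ps(sum);
}
```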

Avoiding Horizontal Reductions

When computing $a \cdot X \cdot b$, we may prefer to evaluate $X \cdot b$ first, thanks to the associativity of matrix multiplication. On tiny inputs, however, the operation may be bottlenecked by the horizontal reduction needed for every row of $X$. Instead, we use more serial loads and broadcasts, but perform only one horizontal accumulation at the end, assuming all the needed intermediaries fit into a single register (or a few, if we also minimize data dependencies).
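A minimal sketch for a 16-dimensional f32 input, where each row of the row-major matrix fills exactly one ZMM register — again an illustration with a hypothetical name, not the exact kernel:

```c
#include <immintrin.h>
#include <stddef.h>

// Bilinear form a^T X b with a single horizontal reduction: broadcast a[i],
// scale row i of `x`, and accumulate a[i] * X[i][j] * b[j] lane-wise,
// deferring the horizontal sum over j until after the loop.
float bilinear_f32_16d(float const *a, float const *x, float const *b) {
    __m512 b_vec = _mm512_loadu_ps(b);
    __m512 partial = _mm512_setzero_ps();
    for (size_t i = 0; i != 16; ++i) {
        __m512 scaled_row = _mm512_mul_ps(_mm512_set1_ps(a[i]), _mm512_loadu_ps(x + i * 16));
        partial = _mm512_fmadd_ps(scaled_row, b_vec, partial);
    }
    return _mm512_reduce_add_ps(partial); // the only horizontal reduction
}
```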

Intel Sapphire Rapids Benchmarks

Running on recent Intel Sapphire Rapids CPUs, one can expect the following performance metrics for 128-dimensional Bilinear Forms for SimSIMD and OpenBLAS:

```
-----------------------------------------------------------------------------------------------------------------
Benchmark                                                       Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------------
bilinear_f64_blas<128d>/min_time:10.000/threads:1            3584 ns         3584 ns      3906234 abs_delta=3.8576a bytes=571.503M/s pairs=279.054k/s relative_error=1.45341f
bilinear_f64c_blas<128d>/min_time:10.000/threads:1           7605 ns         7604 ns      1856665 abs_delta=3.90906a bytes=538.656M/s pairs=131.508k/s relative_error=3.10503f
bilinear_f32_blas<128d>/min_time:10.000/threads:1            1818 ns         1818 ns      7621072 abs_delta=743.294p bytes=563.325M/s pairs=550.122k/s relative_error=301.396n
bilinear_f32c_blas<128d>/min_time:10.000/threads:1           3607 ns         3606 ns      3886483 abs_delta=958.531p bytes=567.864M/s pairs=277.278k/s relative_error=1.4445u
bilinear_f16_haswell<128d>/min_time:10.000/threads:1         1324 ns         1324 ns     10597225 abs_delta=1.31674n bytes=386.742M/s pairs=755.355k/s relative_error=851.968n
bilinear_bf16_haswell<128d>/min_time:10.000/threads:1        1305 ns         1305 ns     10752131 abs_delta=1.33001n bytes=392.464M/s pairs=766.532k/s relative_error=561.046n
bilinear_bf16_genoa<128d>/min_time:10.000/threads:1           862 ns          862 ns     16241596 abs_delta=1.40284n bytes=593.885M/s pairs=1.15993M/s relative_error=849.533n
bilinear_bf16c_genoa<128d>/min_time:10.000/threads:1         2610 ns         2610 ns      5355435 abs_delta=351.596p bytes=392.313M/s pairs=383.118k/s relative_error=243.698n
bilinear_f16_sapphire<128d>/min_time:10.000/threads:1         875 ns          875 ns     16038203 abs_delta=10.5652u bytes=584.951M/s pairs=1.14248M/s relative_error=9.42998m
bilinear_f16c_sapphire<128d>/min_time:10.000/threads:1       2159 ns         2159 ns      6449575 abs_delta=4.43296u bytes=474.398M/s pairs=463.28k/s relative_error=3.98057m
bilinear_f64_skylake<128d>/min_time:10.000/threads:1         3483 ns         3483 ns      4019657 abs_delta=4.3853a bytes=587.96M/s pairs=287.09k/s relative_error=3.02046f
bilinear_f64c_skylake<128d>/min_time:10.000/threads:1        7178 ns         7178 ns      1949803 abs_delta=3.45547a bytes=570.624M/s pairs=139.313k/s relative_error=4.07708f
bilinear_f32_skylake<128d>/min_time:10.000/threads:1        1783 ns         1783 ns      7848896 abs_delta=2.45041n bytes=574.255M/s pairs=560.796k/s relative_error=811.561n
bilinear_f32c_skylake<128d>/min_time:10.000/threads:1       3504 ns         3504 ns      3976879 abs_delta=1.94251n bytes=584.494M/s pairs=285.397k/s relative_error=2.99757u
bilinear_f64_serial<128d>/min_time:10.000/threads:1          5528 ns         5528 ns      2529904 abs_delta=0 bytes=370.459M/s pairs=180.888k/s relative_error=0
bilinear_f64c_serial<128d>/min_time:10.000/threads:1        12324 ns        12324 ns      1140788 abs_delta=0 bytes=332.371M/s pairs=81.1453k/s relative_error=0
bilinear_f32_serial<128d>/min_time:10.000/threads:1          5299 ns         5298 ns      2649614 abs_delta=1.69242n bytes=193.264M/s pairs=188.734k/s relative_error=776.834n
bilinear_f32c_serial<128d>/min_time:10.000/threads:1        10217 ns        10216 ns      1370535 abs_delta=1.89398n bytes=200.461M/s pairs=97.8816k/s relative_error=3.25219u
bilinear_f16_serial<128d>/min_time:10.000/threads:1         42372 ns        42371 ns       330369 abs_delta=1.93284n bytes=12.0838M/s pairs=23.6011k/s relative_error=1.51289u
bilinear_f16c_serial<128d>/min_time:10.000/threads:1        46101 ns        46100 ns       303997 abs_delta=1.77214n bytes=22.2124M/s pairs=21.6918k/s relative_error=1.5494u
bilinear_bf16_serial<128d>/min_time:10.000/threads:1        85325 ns        85324 ns       163256 abs_delta=1.34067n bytes=6.00066M/s pairs=11.72k/s relative_error=527.801n
bilinear_bf16c_serial<128d>/min_time:10.000/threads:1      178970 ns       178967 ns        78235 abs_delta=1.46323n bytes=5.72174M/s pairs=5.58764k/s relative_error=1004.88n
```

Highlights:

  • Single- and double-precision kernels are only about 5% faster than BLAS due to removed temporary buffer stores.
  • Both bf16 and f16 kernels speed up roughly in proportion to the shrinkage of the scalar type: halving the bits roughly doubles the throughput.

On low-dimensional inputs, the performance gap is larger:

```
---------------------------------------------------------------------------------------------------------------
Benchmark                                                     Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------------------------------
bilinear_f64_blas<8d>/min_time:10.000/threads:1            42.7 ns         42.7 ns    328247670 abs_delta=15.9107a bytes=3.00004G/s pairs=23.4378M/s relative_error=550.946a
bilinear_f64c_blas<8d>/min_time:10.000/threads:1           57.4 ns         57.4 ns    243896993 abs_delta=21.3452a bytes=4.46378G/s pairs=17.4366M/s relative_error=514.643a
bilinear_f32_blas<8d>/min_time:10.000/threads:1            32.2 ns         32.2 ns    434784869 abs_delta=6.73645n bytes=3.97757G/s pairs=31.0747M/s relative_error=235.395n
bilinear_f32c_blas<8d>/min_time:10.000/threads:1           50.6 ns         50.6 ns    276504577 abs_delta=7.97379n bytes=2.52823G/s pairs=19.7518M/s relative_error=251.204n
bilinear_f16_haswell<8d>/min_time:10.000/threads:1         13.7 ns         13.7 ns   1000000000 abs_delta=6.06053n bytes=9.35133G/s pairs=73.0573M/s relative_error=139.096n
bilinear_bf16_haswell<8d>/min_time:10.000/threads:1        13.0 ns         13.0 ns   1000000000 abs_delta=5.03892n bytes=9.84787G/s pairs=76.9365M/s relative_error=114.101n
bilinear_bf16_genoa<8d>/min_time:10.000/threads:1          12.6 ns         12.6 ns   1000000000 abs_delta=5.63947n bytes=10.1297G/s pairs=79.1384M/s relative_error=166.305n
bilinear_bf16c_genoa<8d>/min_time:10.000/threads:1         69.0 ns         69.0 ns    203022389 abs_delta=1.61581n bytes=1.85573G/s pairs=14.4979M/s relative_error=60.9203n
bilinear_f16_sapphire<8d>/min_time:10.000/threads:1        8.52 ns         8.52 ns   1000000000 abs_delta=51.4863u bytes=15.0256G/s pairs=117.387M/s relative_error=1.92771m
bilinear_f16c_sapphire<8d>/min_time:10.000/threads:1       64.6 ns         64.6 ns    216692584 abs_delta=43.8492u bytes=1.98133G/s pairs=15.4791M/s relative_error=1.48218m
bilinear_f32_skylake<8d>/min_time:10.000/threads:1         7.28 ns         7.28 ns   1000000000 abs_delta=8.92396n bytes=17.5799G/s pairs=137.343M/s relative_error=266.557n
bilinear_f32c_skylake<8d>/min_time:10.000/threads:1        42.8 ns         42.8 ns    326789735 abs_delta=10.4774n bytes=2.98821G/s pairs=23.3454M/s relative_error=267.67n
bilinear_f64_skylake<8d>/min_time:10.000/threads:1         7.16 ns         7.16 ns   1000000000 abs_delta=16.8322a bytes=17.8732G/s pairs=139.634M/s relative_error=776.898a
bilinear_f64c_skylake<8d>/min_time:10.000/threads:1        31.2 ns         31.2 ns    449958679 abs_delta=17.4692a bytes=8.20188G/s pairs=32.0386M/s relative_error=477.326a
bilinear_f64_serial<8d>/min_time:10.000/threads:1          19.3 ns         19.3 ns    724453573 abs_delta=0 bytes=6.63046G/s pairs=51.8005M/s relative_error=0
bilinear_f64c_serial<8d>/min_time:10.000/threads:1         47.7 ns         47.7 ns    293638808 abs_delta=0 bytes=5.36703G/s pairs=20.965M/s relative_error=0
bilinear_f32_serial<8d>/min_time:10.000/threads:1          18.4 ns         18.4 ns    759547931 abs_delta=7.93122n bytes=6.94336G/s pairs=54.245M/s relative_error=213.04n
bilinear_f32c_serial<8d>/min_time:10.000/threads:1         45.6 ns         45.6 ns    307012654 abs_delta=9.52236n bytes=2.80829G/s pairs=21.9398M/s relative_error=282.08n
bilinear_f16_serial<8d>/min_time:10.000/threads:1           171 ns          171 ns     81713243 abs_delta=7.46151n bytes=747.117M/s pairs=5.83685M/s relative_error=187.409n
bilinear_f16c_serial<8d>/min_time:10.000/threads:1          208 ns          208 ns     67195854 abs_delta=8.79194n bytes=614.281M/s pairs=4.79907M/s relative_error=265.818n
bilinear_bf16_serial<8d>/min_time:10.000/threads:1          359 ns          359 ns     38947709 abs_delta=5.77119n bytes=356.094M/s pairs=2.78198M/s relative_error=122.725n
bilinear_bf16c_serial<8d>/min_time:10.000/threads:1         744 ns          744 ns     18821435 abs_delta=6.72388n bytes=172.071M/s pairs=1.34431M/s relative_error=145.277n
```

Highlights:

  • For f32, throughput grew from 31.07 M to 137.34 M pairs per second, comparing the BLAS baseline to the Skylake kernel.
  • For f64, throughput grew from 23.44 M to 139.63 M pairs per second.

@ashvardanian merged commit 73a8520 into main on Nov 26, 2024
18 checks passed