v6: Future-Proofing Dense & Sparse Operations #217

Merged 5 commits on Nov 4, 2024
4 changes: 3 additions & 1 deletion .vscode/settings.json
@@ -93,7 +93,9 @@
"format": "c",
"execution": "cpp",
"math.h": "c",
"float.h": "c"
"float.h": "c",
"text_encoding": "cpp",
"stdio.h": "c"
},
"cSpell.words": [
"allclose",
33 changes: 31 additions & 2 deletions CONTRIBUTING.md
@@ -101,7 +101,7 @@ You can also benchmark against other libraries, filter the numeric types, and di
$ python scripts/bench_vectors.py --help
> usage: bench.py [-h] [--ndim NDIM] [-n COUNT]
> [--metric {all,dot,spatial,binary,probability,sparse}]
> [--dtype {all,bits,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}]
> [--dtype {all,bin8,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}]
> [--scipy] [--scikit] [--torch] [--tf] [--jax]
>
> Benchmark SimSIMD vs. other libraries
@@ -119,7 +119,7 @@ $ python scripts/bench_vectors.py --help
> `cdist`.
> --metric {all,dot,spatial,binary,probability,sparse}
> Distance metric to use, profiles everything by default
> --dtype {all,bits,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}
> --dtype {all,bin8,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}
> Defines numeric types to benchmark, profiles everything by default
> --scipy Profile SciPy, must be installed
> --scikit Profile scikit-learn, must be installed
@@ -203,6 +203,35 @@ bun test
swift build && swift test -v
```

Running Swift on Linux requires a couple of extra steps, as the Swift compiler is not available in the default repositories.
Please get the most recent Swift tarball from the [official website](https://www.swift.org/install/).
At the time of writing, for a 64-bit Arm CPU running Ubuntu 22.04, the following commands should work:

```bash
wget https://download.swift.org/swift-5.9.2-release/ubuntu2204-aarch64/swift-5.9.2-RELEASE/swift-5.9.2-RELEASE-ubuntu22.04-aarch64.tar.gz
tar xzf swift-5.9.2-RELEASE-ubuntu22.04-aarch64.tar.gz
sudo mv swift-5.9.2-RELEASE-ubuntu22.04-aarch64 /usr/share/swift
echo "export PATH=/usr/share/swift/usr/bin:$PATH" >> ~/.bashrc
source ~/.bashrc
```

You can check the available images on the [`swift.org/download` page](https://www.swift.org/download/#releases).
For x86 CPUs, the following commands would work:

```bash
wget https://download.swift.org/swift-5.9.2-release/ubuntu2204/swift-5.9.2-RELEASE/swift-5.9.2-RELEASE-ubuntu22.04.tar.gz
tar xzf swift-5.9.2-RELEASE-ubuntu22.04.tar.gz
sudo mv swift-5.9.2-RELEASE-ubuntu22.04 /usr/share/swift
echo "export PATH=/usr/share/swift/usr/bin:$PATH" >> ~/.bashrc
source ~/.bashrc
```

Alternatively, on Linux, the official Swift Docker image can be used for builds and tests:

```bash
sudo docker run --rm -v "$PWD:/workspace" -w /workspace swift:5.9 /bin/bash -cl "swift build -c release --static-swift-stdlib && swift test -c release --enable-test-discovery"
```

## GoLang

```sh
126 changes: 90 additions & 36 deletions README.md
@@ -69,9 +69,9 @@ Implemented distance functions include:

Moreover, SimSIMD...

- handles `f64`, `f32`, `f16`, and `bf16` real & complex vectors.
- handles `i8` integral, `i4` sub-byte, and `b8` binary vectors.
- handles sparse `u32` and `u16` sets, and weighted sparse vectors.
- handles `float64`, `float32`, `float16`, and `bfloat16` real & complex vectors.
- handles `int8` integral, `int4` sub-byte, and `b8` binary vectors.
- handles sparse `uint32` and `uint16` sets, and weighted sparse vectors.
- is a zero-dependency [header-only C 99](#using-simsimd-in-c) library.
- has [Python](#using-simsimd-in-python), [Rust](#using-simsimd-in-rust), [JS](#using-simsimd-in-javascript), and [Swift](#using-simsimd-in-swift) bindings.
- has Arm backends for NEON, Scalable Vector Extensions (SVE), and SVE2.
@@ -95,14 +95,14 @@ You can learn more about the technical implementation details in the following b
For reference, we use 1536-dimensional vectors, like the embeddings produced by the OpenAI Ada API.
Comparing the serial code throughput produced by GCC 12 to hand-optimized kernels in SimSIMD, we see the following single-core improvements for the two most common vector-vector similarity metrics - the Cosine similarity and the Euclidean distance:

| Type | Apple M2 Pro | Intel Sapphire Rapids | AWS Graviton 4 |
| :----- | ----------------------------: | -------------------------------: | ------------------------------: |
| `f64` | 18.5 → 28.8 GB/s <br/> + 56 % | 21.9 → 41.4 GB/s <br/> + 89 % | 20.7 → 41.3 GB/s <br/> + 99 % |
| `f32` | 9.2 → 29.6 GB/s <br/> + 221 % | 10.9 → 95.8 GB/s <br/> + 779 % | 4.9 → 41.9 GB/s <br/> + 755 % |
| `f16` | 4.6 → 14.6 GB/s <br/> + 217 % | 3.1 → 108.4 GB/s <br/> + 3,397 % | 5.4 → 39.3 GB/s <br/> + 627 % |
| `bf16` | 4.6 → 26.3 GB/s <br/> + 472 % | 0.8 → 59.5 GB/s <br/> +7,437 % | 2.5 → 29.9 GB/s <br/> + 1,096 % |
| `i8` | 25.8 → 47.1 GB/s <br/> + 83 % | 33.1 → 65.3 GB/s <br/> + 97 % | 35.2 → 43.5 GB/s <br/> + 24 % |
| `u8` | | 32.5 → 66.5 GB/s <br/> + 105 % | |
| Type | Apple M2 Pro | Intel Sapphire Rapids | AWS Graviton 4 |
| :--------- | ----------------------------: | -------------------------------: | ------------------------------: |
| `float64` | 18.5 → 28.8 GB/s <br/> + 56 % | 21.9 → 41.4 GB/s <br/> + 89 % | 20.7 → 41.3 GB/s <br/> + 99 % |
| `float32` | 9.2 → 29.6 GB/s <br/> + 221 % | 10.9 → 95.8 GB/s <br/> + 779 % | 4.9 → 41.9 GB/s <br/> + 755 % |
| `float16` | 4.6 → 14.6 GB/s <br/> + 217 % | 3.1 → 108.4 GB/s <br/> + 3,397 % | 5.4 → 39.3 GB/s <br/> + 627 % |
| `bfloat16` | 4.6 → 26.3 GB/s <br/> + 472 % | 0.8 → 59.5 GB/s <br/> +7,437 % | 2.5 → 29.9 GB/s <br/> + 1,096 % |
| `int8` | 25.8 → 47.1 GB/s <br/> + 83 % | 33.1 → 65.3 GB/s <br/> + 97 % | 35.2 → 43.5 GB/s <br/> + 24 % |
| `uint8` | | 32.5 → 66.5 GB/s <br/> + 105 % | |

Similar speedups are often observed even when compared to BLAS and LAPACK libraries underlying most numerical computing libraries, including NumPy and SciPy in Python.
Broader benchmarking results:
@@ -115,8 +115,8 @@

The package is intended to replace the usage of `numpy.inner`, `numpy.dot`, and `scipy.spatial.distance`.
Aside from drastic performance improvements, SimSIMD significantly improves accuracy in mixed precision setups.
NumPy and SciPy, processing `i8`, `u8` or `f16` vectors, will use the same types for accumulators, while SimSIMD can combine `i8` enumeration, `i16` multiplication, and `i32` accumulation to avoid overflows entirely.
The same applies to processing `f16` and `bf16` values with `f32` precision.
NumPy and SciPy, when processing `int8`, `uint8`, or `float16` vectors, use the same types for accumulators, while SimSIMD can combine `int8` enumeration, `int16` multiplication, and `int32` accumulation to avoid overflows entirely.
The same applies to processing `float16` and `bfloat16` values with `float32` precision.
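
As a minimal illustration of the difference (the NumPy value follows from standard modular `int8` arithmetic; the SimSIMD result assumes the widened accumulators described above):

```py
import numpy as np
import simsimd

a = np.full(10, 100, dtype=np.int8)
b = np.full(10, 100, dtype=np.int8)

# NumPy keeps `int8` for the accumulator, so the true dot product
# of 10 * 100 * 100 = 100,000 wraps around to -96
print(np.dot(a, b))       # -96

# SimSIMD multiplies in `int16` and accumulates in `int32`, so no overflow
print(simsimd.dot(a, b))  # 100000.0
```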

### Installation

@@ -155,14 +155,33 @@ dist = simsimd.vdot(vec1.astype(np.complex64), vec2.astype(np.complex64)) # conj
```

Unlike SciPy, SimSIMD allows explicitly stating the precision of the input vectors, which is especially useful for mixed-precision setups.
The `dtype` argument can be passed both by name and as a positional argument:

```py
dist = simsimd.cosine(vec1, vec2, "i8")
dist = simsimd.cosine(vec1, vec2, "f16")
dist = simsimd.cosine(vec1, vec2, "f32")
dist = simsimd.cosine(vec1, vec2, "f64")
dist = simsimd.hamming(vec1, vec2, "bits")
dist = simsimd.jaccard(vec1, vec2, "bits")
dist = simsimd.cosine(vec1, vec2, "int8")
dist = simsimd.cosine(vec1, vec2, "float16")
dist = simsimd.cosine(vec1, vec2, "float32")
dist = simsimd.cosine(vec1, vec2, "float64")
dist = simsimd.hamming(vec1, vec2, "bin8")
```

With other frameworks, like PyTorch, one can get a richer type system than NumPy's, but the lack of good CPython interoperability makes it hard to pass data without copies.

```py
import numpy as np
buf1 = np.empty(8, dtype=np.uint16)
buf2 = np.empty(8, dtype=np.uint16)

# View the same memory region with PyTorch and randomize it
import torch
vec1 = torch.asarray(memoryview(buf1), copy=False).view(torch.bfloat16)
vec2 = torch.asarray(memoryview(buf2), copy=False).view(torch.bfloat16)
torch.randn(8, out=vec1)
torch.randn(8, out=vec2)

# Both libs will look into the same memory buffers and report the same results
dist_slow = 1 - torch.nn.functional.cosine_similarity(vec1, vec2, dim=0)
dist_fast = simsimd.cosine(buf1, buf2, "bf16")
```

It also allows using SimSIMD for half-precision complex numbers, which NumPy does not support.
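
Here is a minimal sketch of what that can look like, assuming the same `dtype`-override mechanism extends to the `complex32` type listed in the benchmark options above, with interleaved real/imaginary halves stored in a `float16` buffer:

```py
import numpy as np
import simsimd

# NumPy has no `complex32`, so keep interleaved (real, imaginary)
# pairs in a `float16` buffer: 768 halves = 384 complex numbers
vec1 = np.random.randn(768).astype(np.float16)
vec2 = np.random.randn(768).astype(np.float16)

dist = simsimd.dot(vec1, vec2, "complex32")   # complex dot product
dist = simsimd.vdot(vec1, vec2, "complex32")  # conjugate dot product
```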
@@ -235,6 +254,48 @@ distances: DistancesTensor = simsimd.cdist(matrix1, matrix2, metric="cosine")
distances_array: np.ndarray = np.array(distances, copy=True) # now managed by NumPy
```

### Elementwise Kernels

SimSIMD also provides mixed-precision elementwise kernels, where the input vectors and the output have the same numeric type, but the intermediate accumulators are of a higher precision.

```py
import numpy as np
from simsimd import fma, wsum

# Let's take two FullHD video frames
first_frame = np.random.randint(0, 256, 1920 * 1080, dtype=np.uint8)
second_frame = np.random.randint(0, 256, 1920 * 1080, dtype=np.uint8)
average_frame = np.empty_like(first_frame)
wsum(first_frame, second_frame, alpha=0.5, beta=0.5, out=average_frame)

# Slow analog with NumPy:
slow_average_frame = (0.5 * first_frame + 0.5 * second_frame).astype(np.uint8)
```

Similarly, `fma` takes three arguments and computes a fused multiply-add operation.
In applications like Machine Learning, you may also benefit from the "brain-float" format, which NumPy does not natively support; a sketch for that case follows the next example.
In 3D Graphics, for example, we can use FMA to compute the [Phong shading model](https://en.wikipedia.org/wiki/Phong_shading):

```py
# Assume a FullHD frame with random values for simplicity
light_intensity = np.random.rand(1920 * 1080).astype(np.float16) # Intensity of light on each pixel
diffuse_component = np.random.rand(1920 * 1080).astype(np.float16) # Diffuse reflectance on the surface
specular_component = np.random.rand(1920 * 1080).astype(np.float16) # Specular reflectance for highlights
output_color = np.empty_like(light_intensity) # Array to store the resulting color intensity

# Define the scaling factors for diffuse and specular contributions
alpha = 0.7 # Weight for the diffuse component
beta = 0.3 # Weight for the specular component

# Formula: color = alpha * light_intensity * diffuse_component + beta * specular_component
fma(light_intensity, diffuse_component, specular_component,
dtype="float16", # Optional, unless it can't be inferred from the input
alpha=alpha, beta=beta, out=output_color)

# Slow analog with NumPy for comparison
slow_output_color = (alpha * light_intensity * diffuse_component + beta * specular_component).astype(np.float16)
```
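
For the "brain-float" case mentioned above, here is a minimal sketch, assuming the elementwise kernels honor the same `bfloat16` override as the distance kernels, and reusing the `uint16`-buffer trick from the earlier PyTorch example:

```py
import numpy as np
import torch
from simsimd import wsum

# NumPy has no `bfloat16`, so allocate same-size `uint16` buffers
# and let PyTorch fill them with valid `bfloat16` bit patterns
buf1 = np.empty(1920 * 1080, dtype=np.uint16)
buf2 = np.empty(1920 * 1080, dtype=np.uint16)
torch.randn(buf1.size, out=torch.asarray(memoryview(buf1), copy=False).view(torch.bfloat16))
torch.randn(buf2.size, out=torch.asarray(memoryview(buf2), copy=False).view(torch.bfloat16))

# Blend the two buffers, accumulating in `float32` internally
blended = np.empty_like(buf1)
wsum(buf1, buf2, alpha=0.5, beta=0.5, dtype="bfloat16", out=blended)
```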

### Multithreading and Memory Usage

By default, computations use a single CPU core.
@@ -248,15 +309,15 @@ matrix1 = np.packbits(np.random.randint(2, size=(10_000, ndim)).astype(np.uint8)
matrix2 = np.packbits(np.random.randint(2, size=(1_000, ndim)).astype(np.uint8))

distances = simsimd.cdist(matrix1, matrix2,
metric="hamming", # Unlike SciPy, SimSIMD doesn't divide by the number of dimensions
out_dtype="u8", # so we can use `u8` instead of `f64` to save memory.
threads=0, # Use all CPU cores with OpenMP.
dtype="b8", # Override input argument type to `b8` eight-bit words.
metric="hamming", # Unlike SciPy, SimSIMD doesn't divide by the number of dimensions
out_dtype="uint8", # so we can use `uint8` instead of `float64` to save memory.
threads=0, # Use all CPU cores with OpenMP.
dtype="bin8", # Override input argument type to `bin8` eight-bit words.
)
```

By default, the output distances will be stored in double-precision `f64` floating-point numbers.
That behavior may not be space-efficient, especially if you are computing the hamming distance between short binary vectors, that will generally fit into 8x smaller `u8` or `u16` types.
By default, the output distances will be stored in double-precision `float64` floating-point numbers.
That behavior may not be space-efficient, especially if you are computing the Hamming distance between short binary vectors, which will generally fit into 8x smaller `uint8` or `uint16` types.
To override this behavior, use the `out_dtype` argument.

### Helper Functions
@@ -575,7 +636,7 @@ Simplest of all, you can include the headers, and the compiler will automaticall
int main() {
simsimd_f32_t vector_a[1536];
simsimd_f32_t vector_b[1536];
simsimd_metric_punned_t distance_function = simsimd_metric_punned(
simsimd_kernel_punned_t distance_function = simsimd_metric_punned(
simsimd_metric_cos_k, // Metric kind, like the angular cosine distance
simsimd_datatype_f32_k, // Data type, like: f16, f32, f64, i8, b8, and complex variants
simsimd_cap_any_k); // Which CPU capabilities are we allowed to use
@@ -663,7 +724,6 @@ int main() {
simsimd_vdot_f16c(f16s, f16s, 1536, &distance);
simsimd_vdot_f32c(f32s, f32s, 1536, &distance);
simsimd_vdot_f64c(f64s, f64s, 1536, &distance);

return 0;
}
```
@@ -676,13 +736,8 @@ int main() {
int main() {
simsimd_b8_t b8s[1536 / 8]; // 8 bits per word
simsimd_distance_t distance;

// Hamming distance between two vectors
simsimd_hamming_b8(b8s, b8s, 1536 / 8, &distance);

// Jaccard distance between two vectors
simsimd_jaccard_b8(b8s, b8s, 1536 / 8, &distance);

return 0;
}
```
@@ -707,7 +762,6 @@ int main() {
simsimd_kl_f16(f16s, f16s, 1536, &distance);
simsimd_kl_f32(f32s, f32s, 1536, &distance);
simsimd_kl_f64(f64s, f64s, 1536, &distance);

return 0;
}
```
@@ -949,10 +1003,10 @@ In NumPy terms, the implementation may look like:

```py
import numpy as np
def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
def wsum(A: np.ndarray, B: np.ndarray, /, Alpha: float, Beta: float) -> np.ndarray:
assert A.dtype == B.dtype, "Input types must match and affect the output type"
return (Alpha * A + Beta * B).astype(A.dtype)
def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, /, Alpha: float, Beta: float) -> np.ndarray:
assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and affect the output type"
return (Alpha * A * B + Beta * C).astype(A.dtype)
```
@@ -1095,7 +1149,7 @@ All of the function names follow the same pattern: `simsimd_{function}_{type}_{backend}`
- The type can be `f64`, `f32`, `f16`, `bf16`, `f64c`, `f32c`, `f16c`, `bf16c`, `i8`, or `b8`.
- The function can be `dot`, `vdot`, `cos`, `l2sq`, `hamming`, `jaccard`, `kl`, `js`, or `intersect`.

To avoid hard-coding the backend, you can use the `simsimd_metric_punned_t` to pun the function pointer and the `simsimd_capabilities` function to get the available backends at runtime.
To avoid hard-coding the backend, you can use the `simsimd_kernel_punned_t` to pun the function pointer and the `simsimd_capabilities` function to get the available backends at runtime.
To match all the function names, consider a RegEx:

```regex