Skip to content

Commit

Permalink
vector indexing refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Dermody committed Jul 7, 2024
1 parent db702b2 commit 48eec47
Show file tree
Hide file tree
Showing 12 changed files with 363 additions and 222 deletions.
2 changes: 1 addition & 1 deletion BrightData.MKL/MklLinearAlgebraProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ static RT Traverse<RT>(IReadOnlyNumericSegment<float> x, IReadOnlyNumericSegment
{
if (x is INumericSegment<float> x2 && y is INumericSegment<float> y2)
return Traverse(x2, y2, callback);
return x.ApplyReadOnlySpans(y, (a, b) => {
return x.ReduceReadOnlySpans(y, (a, b) => {
fixed(float* p1 = a)
fixed (float* p2 = b) {
return callback(a.Length, p1, 1, p2, 1);
Expand Down
18 changes: 17 additions & 1 deletion BrightData.UnitTests/VectorSetTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BrightData.UnitTests.Helper;
using System.Linq;
using BrightData.LinearAlgebra.VectorIndexing;
using BrightData.Types;
using FluentAssertions;
using Xunit;
Expand All @@ -14,7 +15,7 @@ public void Average()
var set = new VectorSet<float>(4);
set.Add(_context.CreateReadOnlyVector(0, 0, 0, 0));
set.Add(_context.CreateReadOnlyVector(1, 1, 1, 1));
var average = set.GetAverage(new[] { 0U, 1U });
var average = set.GetAverage([0U, 1U]);
average.Should().AllBeEquivalentTo(0.5f);
}

Expand All @@ -37,5 +38,20 @@ public void Rank2()
var rank = set.Rank(_context.CreateReadOnlyVector(0.45f, 0.45f, 0.45f, 0.45f));
rank.First().Should().Be(0);
}

[Fact]
public void Closest()
{
    // Stored vectors: all zeros is added first, all ones is added second.
    var vectors = new VectorSet<float>(4);
    vectors.Add(_context.CreateReadOnlyVector(0, 0, 0, 0));
    vectors.Add(_context.CreateReadOnlyVector(1, 1, 1, 1));

    // Candidate vectors; the trailing comments give each candidate's position in the array.
    var closestIndices = vectors.Closest([
        _context.CreateReadOnlyVector(0.5f, 0.5f, 0.5f, 0.5f), // 0
        _context.CreateReadOnlyVector(0.9f, 0.9f, 0.9f, 0.9f), // 1
        _context.CreateReadOnlyVector(0.1f, 0.1f, 0.1f, 0.1f), // 2
    ], DistanceMetric.Euclidean);

    // The first result should point at candidate 2 (all 0.1) and the second at candidate 1 (all 0.9);
    // presumably one result per stored vector, in insertion order - confirm against VectorSet.Closest.
    closestIndices[0].Should().Be(2);
    closestIndices[1].Should().Be(1);
}
}
}
22 changes: 11 additions & 11 deletions BrightData/ExtensionMethods.TensorSegment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,68 +43,68 @@ public static WeightedIndexList ToSparse(this IReadOnlyNumericSegment<float> seg
/// <param name="segment"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static (float Min, float Max, uint MinIndex, uint MaxIndex) GetMinAndMaxValues(this IReadOnlyNumericSegment<float> segment) => segment.ApplyReadOnlySpan(x => x.GetMinAndMaxValues());
public static (T Min, T Max, uint MinIndex, uint MaxIndex) GetMinAndMaxValues<T>(this IReadOnlyNumericSegment<T> segment) where T : unmanaged, INumber<T>, IMinMaxValue<T> => segment.ApplyReadOnlySpan(x => x.GetMinAndMaxValues());

/// <summary>
/// Returns the index with the minimum value from this tensor segment
/// </summary>
/// <param name="segment"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static uint GetMinimumIndex(this IReadOnlyNumericSegment<float> segment) => GetMinAndMaxValues(segment).MinIndex;
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static uint GetMinimumIndex<T>(this IReadOnlyNumericSegment<T> segment) where T : unmanaged, INumber<T>, IMinMaxValue<T> => GetMinAndMaxValues(segment).MinIndex;

/// <summary>
/// Returns the index with the maximum value from this tensor segment
/// </summary>
/// <param name="segment"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static uint GetMaximumIndex(this IReadOnlyNumericSegment<float> segment) => GetMinAndMaxValues(segment).MaxIndex;
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static uint GetMaximumIndex<T>(this IReadOnlyNumericSegment<T> segment) where T : unmanaged, INumber<T>, IMinMaxValue<T> => GetMinAndMaxValues(segment).MaxIndex;

/// <summary>
/// Sums all values
/// </summary>
/// <param name="segment"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float Sum(this IReadOnlyNumericSegment<float> segment) => segment.ApplyReadOnlySpan(x => x.Sum());
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T Sum<T>(this IReadOnlyNumericSegment<T> segment) where T : unmanaged, IBinaryFloatingPointIeee754<T> => segment.ApplyReadOnlySpan(x => x.Sum());

/// <summary>
/// Finds cosine distance (0 for parallel, 1 for orthogonal, 2 for opposite) between this and another vector
/// </summary>
/// <param name="vector"></param>
/// <param name="other"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float CosineDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other) => vector.ApplyReadOnlySpans(other, (x,y) => x.CosineDistance(y));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T CosineDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.CosineDistance(y));

/// <summary>
/// Finds the euclidean distance between this and another vector
/// </summary>
/// <param name="vector"></param>
/// <param name="other"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float EuclideanDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other) => vector.ApplyReadOnlySpans(other, (x,y) => x.EuclideanDistance(y));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T EuclideanDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.EuclideanDistance(y));

/// <summary>
/// Finds the manhattan distance between this and another vector
/// </summary>
/// <param name="vector"></param>
/// <param name="other"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float ManhattanDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other) => vector.ApplyReadOnlySpans(other, (x,y) => x.ManhattanDistance(y));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T ManhattanDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.ManhattanDistance(y));

/// <summary>
/// Finds the mean squared distance between this and another vector
/// </summary>
/// <param name="vector"></param>
/// <param name="other"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float MeanSquaredDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other) => vector.ApplyReadOnlySpans(other, (x,y) => x.MeanSquaredDistance(y));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T MeanSquaredDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.MeanSquaredDistance(y));

/// <summary>
/// Finds the squared euclidean distance between this and another vector
/// </summary>
/// <param name="vector"></param>
/// <param name="other"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float SquaredEuclideanDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other) => vector.ApplyReadOnlySpans(other, (x,y) => x.SquaredEuclideanDistance(y));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T SquaredEuclideanDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.SquaredEuclideanDistance(y));

/// <summary>
/// Finds the distance between this and another vector
Expand All @@ -113,7 +113,7 @@ public static WeightedIndexList ToSparse(this IReadOnlyNumericSegment<float> seg
/// <param name="other"></param>
/// <param name="distance"></param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static float FindDistance(this IReadOnlyNumericSegment<float> vector, IReadOnlyNumericSegment<float> other, DistanceMetric distance) => vector.ApplyReadOnlySpans(other, (x,y) => x.FindDistance(y, distance));
[MethodImpl(MethodImplOptions.AggressiveInlining)] public static T FindDistance<T>(this IReadOnlyNumericSegment<T> vector, IReadOnlyNumericSegment<T> other, DistanceMetric distance) where T : unmanaged, IBinaryFloatingPointIeee754<T> => vector.ReduceReadOnlySpans(other, (x,y) => x.FindDistance(y, distance));

/// <summary>
/// Splits this tensor segment into multiple contiguous tensor segments
Expand Down Expand Up @@ -417,7 +417,7 @@ public static void ApplyReadOnlySpans<T>(this IReadOnlyNumericSegment<T> segment
/// <param name="segment2"></param>
/// <param name="callback"></param>
/// <returns></returns>
public static RT ApplyReadOnlySpans<T, RT>(this IReadOnlyNumericSegment<T> segment1, IReadOnlyNumericSegment<T> segment2, TransformReadOnlySpans<T, RT> callback) where T: unmanaged, INumber<T>
public static RT ReduceReadOnlySpans<T, RT>(this IReadOnlyNumericSegment<T> segment1, IReadOnlyNumericSegment<T> segment2, TransformReadOnlySpans<T, RT> callback) where T: unmanaged, INumber<T>
{
SpanOwner<T> temp1 = SpanOwner<T>.Empty, temp2 = SpanOwner<T>.Empty;
bool wasTemp1Used = false, wasTemp2Used = false;
Expand Down
2 changes: 2 additions & 0 deletions BrightData/Interfaces.LinearAlgebra.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Numerics;
using BrightData.Helper;
using BrightData.LinearAlgebra;
using CommunityToolkit.HighPerformance.Buffers;

Expand Down Expand Up @@ -1711,4 +1712,5 @@ public interface ICostFunction<T> where T: unmanaged, INumber<T>
/// <returns></returns>
IReadOnlyNumericSegment<T> Gradient(IReadOnlyNumericSegment<T> predicted, IReadOnlyNumericSegment<T> expected);
}

}
120 changes: 120 additions & 0 deletions BrightData/Interfaces.VectorIndexing.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
using System.Threading.Tasks;

namespace BrightData
{
/// <summary>
/// Represents how vectors should be indexed
/// </summary>
/// <remarks>
/// <see cref="Flat"/> is the only strategy declared in this enum.
/// </remarks>
public enum VectorIndexStrategy
{
/// <summary>
/// Flat indexing
/// (presumably an exhaustive comparison against every stored vector rather than a tree or graph structure - TODO confirm against the implementation)
/// </summary>
Flat
}

/// <summary>
/// Determines how vectors are stored
/// </summary>
/// <remarks>
/// <see cref="InMemory"/> is the only storage type declared in this enum.
/// </remarks>
public enum VectorStorageType
{
/// <summary>
/// Vectors are stored in memory
/// </summary>
InMemory
}

/// <summary>
/// Responsible for storing vectors
/// </summary>
/// <remarks>
/// Non-generic base contract: it only declares members that do not depend on the vector element type.
/// It also inherits <see cref="IHaveSize"/>, which is declared elsewhere
/// (presumably reporting the number of stored vectors - TODO confirm).
/// </remarks>
public interface IStoreVectors : IHaveSize
{
/// <summary>
/// Storage type
/// </summary>
VectorStorageType StorageType { get; }

/// <summary>
/// Size of each vector (fixed)
/// </summary>
uint VectorSize { get; }

/// <summary>
/// Removes a vector at the specified index
/// </summary>
/// <param name="index">Index of the vector to remove</param>
/// <remarks>
/// NOTE(review): this contract does not say whether the indices of the remaining vectors
/// change after a removal - confirm with the implementations.
/// </remarks>
void Remove(uint index);
}

/// <summary>
/// Stores typed vectors
/// </summary>
/// <typeparam name="T">
/// Element type of the stored vectors: an unmanaged IEEE 754 binary floating point type
/// that also exposes minimum and maximum values
/// </typeparam>
public interface IStoreVectors<T> : IStoreVectors where T : unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
{
/// <summary>
/// Adds a vector
/// </summary>
/// <param name="vector">Vector to add</param>
/// <returns>Presumably the index at which the vector was stored - TODO confirm with the implementations</returns>
uint Add(IReadOnlyVector<T> vector);

/// <summary>
/// Returns a segment at the specified index
/// </summary>
/// <param name="index">Index of the stored vector</param>
/// <returns>The stored vector exposed as a read only numeric segment</returns>
IReadOnlyNumericSegment<T> this[uint index] { get; }

/// <summary>
/// Passes each vector to the callback, possibly in parallel
/// </summary>
/// <param name="callback">
/// Invoked with each vector and a <see cref="uint"/> (presumably that vector's index - TODO confirm).
/// Because invocation may be parallel, the callback should not rely on call order.
/// </param>
void ForEach(Action<IReadOnlyVector<T>, uint> callback);
}

/// <summary>
/// A vector set index
/// </summary>
/// <typeparam name="T">
/// Element type of the indexed vectors: an unmanaged IEEE 754 binary floating point type
/// that also exposes minimum and maximum values
/// </typeparam>
public interface IVectorIndex<T> where T: unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
{
/// <summary>
/// The vector storage for the index
/// </summary>
IStoreVectors<T> Storage { get; }

/// <summary>
/// Adds a vector to the index
/// </summary>
/// <param name="vector">Vector to add</param>
/// <returns>Presumably the index assigned to the added vector - TODO confirm with the implementations</returns>
uint Add(IReadOnlyVector<T> vector);

/// <summary>
/// Removes a vector at the specified index
/// </summary>
/// <param name="index">Index of the vector to remove</param>
void Remove(uint index);

/// <summary>
/// Returns a list of vector indices ranked by the distance between that vector and a comparison vector
/// </summary>
/// <param name="vector">Vector to compare</param>
/// <param name="distanceMetric">Distance metric used to compare the vectors</param>
/// <returns>Vector indices in ranked order (presumably closest first - TODO confirm with the implementations)</returns>
IEnumerable<uint> Rank(IReadOnlyVector<T> vector, DistanceMetric distanceMetric);

/// <summary>
/// Returns the index of the closest vector in the set to each of the supplied vectors
/// </summary>
/// <param name="vector">Vectors to compare</param>
/// <param name="distanceMetric">Distance metric used to compare the vectors</param>
/// <returns>
/// Array of indices.
/// NOTE(review): confirm with the implementations whether there is one entry per supplied vector
/// (indexing into the set) or one entry per stored vector (indexing into the supplied array).
/// </returns>
uint[] Closest(IReadOnlyVector<T>[] vector, DistanceMetric distanceMetric);
}
}
2 changes: 1 addition & 1 deletion BrightData/LinearAlgebra/Clustering/KMeans.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.Collections.Generic;
using System.Linq;
using BrightData.Types;
using BrightData.LinearAlgebra.VectorIndexing;

namespace BrightData.LinearAlgebra.Clustering
{
Expand Down
Loading

0 comments on commit 48eec47

Please sign in to comment.