Skip to content

Commit

Permalink
added single to many vector distance calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Dermody committed Jul 27, 2024
1 parent cd71ccb commit 3ff5c4e
Show file tree
Hide file tree
Showing 30 changed files with 6,308 additions and 3,318 deletions.
5 changes: 4 additions & 1 deletion BrightData.Cuda/BrightData.Cuda.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

74 changes: 71 additions & 3 deletions BrightData.Cuda/CudaLinearAlgebraProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using BrightData.LinearAlgebra;
using BrightData.LinearAlgebra.Segments;
using CommunityToolkit.HighPerformance.Buffers;
using static System.Runtime.InteropServices.JavaScript.JSType;

namespace BrightData.Cuda
{
Expand Down Expand Up @@ -75,6 +76,23 @@ public override INumericSegment<float> CreateSegment(params float[] data)
return new CudaTensorSegment(deviceMemory, Provider);
}

/// <inheritdoc />
public override INumericSegment<float> CreateSegment(IReadOnlyNumericSegment<float> segment)
{
var deviceMemory = Provider.Allocate(segment.Size);
var temp = SpanOwner<float>.Empty;
var wasTempUsed = false;
try {
var span = segment.GetSpan(ref temp, out wasTempUsed);
deviceMemory.CopyToDevice(span, 0);
}
finally {
if (wasTempUsed)
temp.Dispose();
}
return new CudaTensorSegment(deviceMemory, Provider);
}

internal CudaTensorSegment CreateCudaTensorSegment(IDeviceMemoryPtr ptr) => new(ptr, Provider);

/// <inheritdoc />
Expand Down Expand Up @@ -713,14 +731,14 @@ public override IMatrix<float> CreateMatrix(uint rows, uint columns, bool initia
}

/// <inheritdoc />
public override IMatrix<float> FindDistances(IVector<float>[] vectors, IReadOnlyList<IVector<float>> compareTo, DistanceMetric distanceMetric)
public override IMatrix<float> FindDistances(IReadOnlyList<IReadOnlyNumericSegment<float>> vectors, IReadOnlyList<IReadOnlyNumericSegment<float>> compareTo, DistanceMetric distanceMetric)
{
if (distanceMetric is not (DistanceMetric.Euclidean or DistanceMetric.Manhattan or DistanceMetric.Cosine))
throw new NotImplementedException();

var size = vectors[0].Size;
var rows = (uint)compareTo.Count;
var columns = (uint)vectors.Length;
var columns = (uint)vectors.Count;
var ret = Provider.Allocate(rows * columns, null, true);

using (var vectorPtr = new PtrToDeviceMemoryList(vectors.Cast<IHaveDeviceMemory>().ToArray()))
Expand All @@ -746,7 +764,7 @@ public override IMatrix<float> FindDistances(IVector<float>[] vectors, IReadOnly
return ones.Subtract(distance);
}

Provider.CalculateDistances(size, columns, rows,
Provider.CalculateMultiDistances(size, columns, rows,
vectorPtr.DevicePointer,
compareToPtr.DevicePointer,
ret.DevicePointer,
Expand All @@ -764,6 +782,56 @@ public override IMatrix<float> FindDistances(IVector<float>[] vectors, IReadOnly
return matrix;
}

public override IVector<float> FindDistances(IReadOnlyNumericSegment<float> vector, IReadOnlyList<IReadOnlyNumericSegment<float>> compareTo, DistanceMetric distanceMetric)
{
if (distanceMetric is not (DistanceMetric.Euclidean or DistanceMetric.Manhattan or DistanceMetric.Cosine))
throw new NotImplementedException();

var size = vector.Size;
var numVectors = (uint)compareTo.Count;
var ret = Provider.Allocate(numVectors, null, true);

var vectorPtr = (IHaveDeviceMemory)vector;
using (var compareToPtr = new PtrToDeviceMemoryList(compareTo.Cast<IHaveDeviceMemory>().ToArray())) {
if (distanceMetric == DistanceMetric.Cosine) {
var aa = Provider.Allocate(numVectors, null, true);
var bb = Provider.Allocate(numVectors, null, true);
Provider.CosineDistances(size, numVectors,
vectorPtr.Memory.DevicePointer,
compareToPtr.DevicePointer,
aa.DevicePointer,
ret.DevicePointer,
bb.DevicePointer
);
using var ones = CreateVector(numVectors, _ => 1f);
using var vectorMagnitude = new CudaVector(CreateCudaTensorSegment(aa), this);
using var vectorSqrt = vectorMagnitude.Sqrt();
using var compareToMagnitude = new CudaVector(CreateCudaTensorSegment(bb), this);
using var compareToSqrt = compareToMagnitude.Sqrt();
using var norms = vectorSqrt.PointwiseMultiply(compareToSqrt);
using var result = new CudaVector(CreateCudaTensorSegment(ret), this);
using var distance = result.PointwiseDivide(norms);
return ones.Subtract(distance);
}

Provider.CalculateDistances(size, numVectors,
vectorPtr.Memory.DevicePointer,
compareToPtr.DevicePointer,
ret.DevicePointer,
distanceMetric
);
}

IVector<float> matrix = new CudaVector(CreateCudaTensorSegment(ret), this);
if (distanceMetric == DistanceMetric.Euclidean) {
var sqrt = matrix.Sqrt();
matrix.Dispose();
matrix = sqrt;
}

return matrix;
}

/// <inheritdoc />
public override void BindThread()
{
Expand Down
155 changes: 92 additions & 63 deletions BrightData.Cuda/CudaProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ readonly CuFunction
_tensorReverseMaxPool,
_tensorReverseIm2Col,
_isFinite,
_calculateDistance,
_calculateMultiDistances,
_calculateDistances,
_cosineDistances,
_roundInPlace,
_scale
;
Expand Down Expand Up @@ -197,66 +199,68 @@ public CudaProvider(BrightDataContext context, string? cudaKernelPath, string? c
});
_cuda.SetCurrent();

_pointwiseMultiply = _kernel.LoadFunction("PointwiseMultiply");
_addInPlace = _kernel.LoadFunction("AddInPlace");
_subtractInPlace = _kernel.LoadFunction("SubtractInPlace");
_addToEachRow = _kernel.LoadFunction("AddToEachRow");
_addToEachColumn = _kernel.LoadFunction("AddToEachColumn");
_multiplyByEachRow = _kernel.LoadFunction("MultiplyByEachRow");
_multiplyByEachColumn = _kernel.LoadFunction("MultiplyByEachColumn");
_tanh = _kernel.LoadFunction("TanH");
_tanhDerivative = _kernel.LoadFunction("TanHDerivative");
_sigmoid = _kernel.LoadFunction("Sigmoid");
_sigmoidDerivative = _kernel.LoadFunction("SigmoidDerivative");
_sumRows = _kernel.LoadFunction("SumRows");
_relu = _kernel.LoadFunction("RELU");
_reluDerivative = _kernel.LoadFunction("RELUDerivative");
_memSet = _kernel.LoadFunction("MemSet");
_memCpy = _kernel.LoadFunction("MemCpy");
_sumColumns = _kernel.LoadFunction("SumColumns");
_pointwiseDivide = _kernel.LoadFunction("PointwiseDivide");
_sqrt = _kernel.LoadFunction("Sqrt");
_findMinAndMax = _kernel.LoadFunction("FindMinAndMax");
_sumValues = _kernel.LoadFunction("SumValues");
_findStdDev = _kernel.LoadFunction("FindStdDev");
_constrain = _kernel.LoadFunction("Constrain");
_pow = _kernel.LoadFunction("Pow");
_diagonal = _kernel.LoadFunction("Diagonal");
_l1Regularisation = _kernel.LoadFunction("L1Regularisation");
_leakyRelu = _kernel.LoadFunction("LeakyRELU");
_leakyReluDerivative = _kernel.LoadFunction("LeakyRELUDerivative");
_pointwiseDivideRows = _kernel.LoadFunction("PointwiseDivideRows");
_pointwiseDivideColumns = _kernel.LoadFunction("PointwiseDivideColumns");
_splitRows = _kernel.LoadFunction("SplitRows");
_splitColumns = _kernel.LoadFunction("SplitColumns");
_concatRows = _kernel.LoadFunction("ConcatRows");
_concatColumns = _kernel.LoadFunction("ConcatColumns");
_euclideanDistance = _kernel.LoadFunction("EuclideanDistance");
_manhattanDistance = _kernel.LoadFunction("ManhattanDistance");
_cosineDistance = _kernel.LoadFunction("CosineDistance");
_abs = _kernel.LoadFunction("Abs");
_normalise = _kernel.LoadFunction("Normalise");
_softmaxVector = _kernel.LoadFunction("SoftmaxVector");
_multiCosine = _kernel.LoadFunction("MultiCosineDistance");
_log = _kernel.LoadFunction("Log");
_exp = _kernel.LoadFunction("Exp");
_vectorAddInPlace = _kernel.LoadFunction("VectorAddInPlace");
_vectorCopyRandom = _kernel.LoadFunction("VectorCopyRandom");
_copyToMatrixColumns = _kernel.LoadFunction("CopyToMatrixColumns");
_copyToMatrixRows = _kernel.LoadFunction("CopyToMatrixRows");
_tensorAddPadding = _kernel.LoadFunction("TensorAddPadding");
_tensorRemovePadding = _kernel.LoadFunction("TensorRemovePadding");
_tensorIm2Col = _kernel.LoadFunction("TensorIm2Col");
_softmaxDerivative = _kernel.LoadFunction("SoftmaxDerivative");
_reverse = _kernel.LoadFunction("Reverse");
_rotateInPlace = _kernel.LoadFunction("RotateInPlace");
_tensorMaxPool = _kernel.LoadFunction("TensorMaxPool");
_tensorReverseMaxPool = _kernel.LoadFunction("TensorReverseMaxPool");
_tensorReverseIm2Col = _kernel.LoadFunction("TensorReverseIm2Col");
_isFinite = _kernel.LoadFunction("IsFinite");
_calculateDistance = _kernel.LoadFunction("CalculateDistances");
_roundInPlace = _kernel.LoadFunction("RoundInPlace");
_scale = _kernel.LoadFunction("Scale");
_pointwiseMultiply = _kernel.LoadFunction("PointwiseMultiply");
_addInPlace = _kernel.LoadFunction("AddInPlace");
_subtractInPlace = _kernel.LoadFunction("SubtractInPlace");
_addToEachRow = _kernel.LoadFunction("AddToEachRow");
_addToEachColumn = _kernel.LoadFunction("AddToEachColumn");
_multiplyByEachRow = _kernel.LoadFunction("MultiplyByEachRow");
_multiplyByEachColumn = _kernel.LoadFunction("MultiplyByEachColumn");
_tanh = _kernel.LoadFunction("TanH");
_tanhDerivative = _kernel.LoadFunction("TanHDerivative");
_sigmoid = _kernel.LoadFunction("Sigmoid");
_sigmoidDerivative = _kernel.LoadFunction("SigmoidDerivative");
_sumRows = _kernel.LoadFunction("SumRows");
_relu = _kernel.LoadFunction("RELU");
_reluDerivative = _kernel.LoadFunction("RELUDerivative");
_memSet = _kernel.LoadFunction("MemSet");
_memCpy = _kernel.LoadFunction("MemCpy");
_sumColumns = _kernel.LoadFunction("SumColumns");
_pointwiseDivide = _kernel.LoadFunction("PointwiseDivide");
_sqrt = _kernel.LoadFunction("Sqrt");
_findMinAndMax = _kernel.LoadFunction("FindMinAndMax");
_sumValues = _kernel.LoadFunction("SumValues");
_findStdDev = _kernel.LoadFunction("FindStdDev");
_constrain = _kernel.LoadFunction("Constrain");
_pow = _kernel.LoadFunction("Pow");
_diagonal = _kernel.LoadFunction("Diagonal");
_l1Regularisation = _kernel.LoadFunction("L1Regularisation");
_leakyRelu = _kernel.LoadFunction("LeakyRELU");
_leakyReluDerivative = _kernel.LoadFunction("LeakyRELUDerivative");
_pointwiseDivideRows = _kernel.LoadFunction("PointwiseDivideRows");
_pointwiseDivideColumns = _kernel.LoadFunction("PointwiseDivideColumns");
_splitRows = _kernel.LoadFunction("SplitRows");
_splitColumns = _kernel.LoadFunction("SplitColumns");
_concatRows = _kernel.LoadFunction("ConcatRows");
_concatColumns = _kernel.LoadFunction("ConcatColumns");
_euclideanDistance = _kernel.LoadFunction("EuclideanDistance");
_manhattanDistance = _kernel.LoadFunction("ManhattanDistance");
_cosineDistance = _kernel.LoadFunction("CosineDistance");
_cosineDistances = _kernel.LoadFunction("CosineDistances");
_abs = _kernel.LoadFunction("Abs");
_normalise = _kernel.LoadFunction("Normalise");
_softmaxVector = _kernel.LoadFunction("SoftmaxVector");
_multiCosine = _kernel.LoadFunction("CosineMultiDistance");
_log = _kernel.LoadFunction("Log");
_exp = _kernel.LoadFunction("Exp");
_vectorAddInPlace = _kernel.LoadFunction("VectorAddInPlace");
_vectorCopyRandom = _kernel.LoadFunction("VectorCopyRandom");
_copyToMatrixColumns = _kernel.LoadFunction("CopyToMatrixColumns");
_copyToMatrixRows = _kernel.LoadFunction("CopyToMatrixRows");
_tensorAddPadding = _kernel.LoadFunction("TensorAddPadding");
_tensorRemovePadding = _kernel.LoadFunction("TensorRemovePadding");
_tensorIm2Col = _kernel.LoadFunction("TensorIm2Col");
_softmaxDerivative = _kernel.LoadFunction("SoftmaxDerivative");
_reverse = _kernel.LoadFunction("Reverse");
_rotateInPlace = _kernel.LoadFunction("RotateInPlace");
_tensorMaxPool = _kernel.LoadFunction("TensorMaxPool");
_tensorReverseMaxPool = _kernel.LoadFunction("TensorReverseMaxPool");
_tensorReverseIm2Col = _kernel.LoadFunction("TensorReverseIm2Col");
_isFinite = _kernel.LoadFunction("IsFinite");
_calculateMultiDistances = _kernel.LoadFunction("CalculateMultiDistances");
_calculateDistances = _kernel.LoadFunction("CalculateDistances");
_roundInPlace = _kernel.LoadFunction("RoundInPlace");
_scale = _kernel.LoadFunction("Scale");
}

/// <summary>
Expand Down Expand Up @@ -1058,9 +1062,22 @@ internal void MultiCosine(uint size, uint columns, uint rows, CuDevicePtr vector
);
}

internal void CalculateDistances(uint size, uint columns, uint rows, CuDevicePtr vectorPtr, CuDevicePtr compareToPtr, CuDevicePtr ret, DistanceMetric distanceMetric)
internal void CosineDistances(uint size, uint numVectors, CuDevicePtr vectorPtr, CuDevicePtr compareToPtr, CuDevicePtr aa, CuDevicePtr ret, CuDevicePtr bb)
{
InvokeTensor(_calculateDistance, null, size, columns, rows,
InvokeMatrix(_cosineDistances, null, size, numVectors,
vectorPtr,
compareToPtr,
aa,
ret,
bb,
numVectors,
size
);
}

internal void CalculateMultiDistances(uint size, uint columns, uint rows, CuDevicePtr vectorPtr, CuDevicePtr compareToPtr, CuDevicePtr ret, DistanceMetric distanceMetric)
{
InvokeTensor(_calculateMultiDistances, null, size, columns, rows,
vectorPtr,
compareToPtr,
ret,
Expand All @@ -1071,6 +1088,18 @@ internal void CalculateDistances(uint size, uint columns, uint rows, CuDevicePtr
);
}

internal void CalculateDistances(uint size, uint numVectors, CuDevicePtr vectorPtr, CuDevicePtr compareToPtr, CuDevicePtr ret, DistanceMetric distanceMetric)
{
InvokeMatrix(_calculateDistances, null, size, numVectors,
vectorPtr,
compareToPtr,
ret,
numVectors,
size,
(uint)distanceMetric
);
}

internal void CopyToMatrixRows(uint rows, uint columns, CudaDeviceVariable<CuDevicePtr> from, IDeviceMemoryPtr to, CuStream* stream = null)
{
InvokeMatrix(_copyToMatrixRows, stream, rows, columns, from.DevicePointer, to.DevicePointer, rows, columns);
Expand Down
3 changes: 2 additions & 1 deletion BrightData.Cuda/CudaTensorSegment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace BrightData.Cuda
{
internal class CudaTensorSegment(IDeviceMemoryPtr data, CudaProvider provider) : INumericSegment<float>
internal class CudaTensorSegment(IDeviceMemoryPtr data, CudaProvider provider) : INumericSegment<float>, IHaveDeviceMemory
{
const string CudaSegmentType = "cuda";

Expand All @@ -29,6 +29,7 @@ public static bool IsCuda(IReadOnlyNumericSegment<float> segment, [NotNullWhen(t
public int Release() => DeviceMemory.Release();

public IDeviceMemoryPtr DeviceMemory { get; } = data;
IDeviceMemoryPtr IHaveDeviceMemory.Memory => DeviceMemory;
public bool IsValid => DeviceMemory.IsValid;
public uint Size => DeviceMemory.Size;
public string SegmentType => CudaSegmentType;
Expand Down
Loading

0 comments on commit 3ff5c4e

Please sign in to comment.