Skip to content

Commit

Permalink
Twidmer/radix sort for float keys
Browse files Browse the repository at this point in the history
  • Loading branch information
nvtw authored and mmacklin committed Dec 3, 2024
1 parent fa0add3 commit 29d6783
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 23 deletions.
3 changes: 3 additions & 0 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3062,6 +3062,9 @@ def __init__(self):
self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

self.core.runlength_encode_int_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
Expand Down
85 changes: 85 additions & 0 deletions warp/native/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,90 @@ void radix_sort_pairs_host(int* keys, int* values, int n)
}
}

// Converts a float into an unsigned integer key whose unsigned ordering matches
// the numeric ordering of the floats, so floats can be radix sorted as integers.
// Technique from http://stereopsis.com/radix.html
//
//   - sign bit set (negative values): every bit is flipped, which reverses the
//     order of negatives and places them below all non-negative values
//   - sign bit clear: only the sign bit is flipped, which moves non-negative
//     values into the upper half of the unsigned range
//
// Consequently -0.0f maps to a smaller key than +0.0f. NaNs are ordered purely by
// their bit pattern (sign-bit-set NaNs first, sign-bit-clear NaNs last).
inline unsigned int radix_float_to_int(float f)
{
    static_assert(sizeof(unsigned int) == sizeof(float), "radix_float_to_int requires a 32-bit float and unsigned int");

    // Copy the bits out with memcpy: reading a float through an unsigned int
    // reference violates strict aliasing and is undefined behaviour.
    unsigned int bits;
    memcpy(&bits, &f, sizeof(bits));

    const unsigned int mask = (bits & 0x80000000u) ? 0xffffffffu : 0x80000000u;
    return bits ^ mask;
}

// Sorts the first n (key, value) pairs on the host in ascending key order.
//
// Two-pass least-significant-digit radix sort with 16-bit digits, applied to the
// order-preserving integer form of each float key (radix_float_to_int).
//
// keys / values: both arrays must provide room for 2*n elements. Elements
//                [n, 2n) are used as scratch space and are overwritten.
// n:             number of pairs to sort.
//
// The bucket table has static storage, so concurrent calls share it.
void radix_sort_pairs_host(float* keys, int* values, int n)
{
    // buckets[0] is indexed by the low 16 bits of a key, buckets[1] by the high 16 bits
    static unsigned int buckets[2][1 << 16];
    memset(buckets, 0, sizeof(buckets));

    // scratch storage lives directly behind the n live elements
    float* scratch_keys = keys + n;
    int* scratch_values = values + n;

    // count how often each digit value occurs
    for (int idx = 0; idx < n; ++idx)
    {
        const unsigned int bits = radix_float_to_int(keys[idx]);

        buckets[0][bits & 0xffff] += 1;
        buckets[1][bits >> 16] += 1;
    }

    // exclusive prefix sum: each bucket ends up holding the first output slot for its digit
    unsigned int running_low = 0;
    unsigned int running_high = 0;

    for (int digit = 0; digit < 65536; ++digit)
    {
        const unsigned int count_low = buckets[0][digit];
        const unsigned int count_high = buckets[1][digit];

        buckets[0][digit] = running_low;
        buckets[1][digit] = running_high;

        running_low += count_low;
        running_high += count_high;
    }

    // first pass: scatter keys/values -> scratch, ordered by the low digit
    for (int src = 0; src < n; ++src)
    {
        const float key = keys[src];
        const int value = values[src];
        const unsigned int digit = radix_float_to_int(key) & 0xffff;

        // claim the next free slot of this digit's bucket
        const unsigned int dst = buckets[0][digit]++;

        scratch_keys[dst] = key;
        scratch_values[dst] = value;
    }

    // second pass: scatter scratch -> keys/values, ordered by the high digit
    for (int src = 0; src < n; ++src)
    {
        const float key = scratch_keys[src];
        const int value = scratch_values[src];
        const unsigned int digit = radix_float_to_int(key) >> 16;

        const unsigned int dst = buckets[1][digit]++;

        keys[dst] = key;
        values[dst] = value;
    }
}

// Stand-in definitions compiled only when WP_ENABLE_CUDA evaluates to false, so
// that the device-side radix sort symbols are still defined in this translation unit.
// Every stub has an empty body: its arguments are ignored and nothing is sorted.
#if !WP_ENABLE_CUDA

// No-op: *mem_out and *size_out are left untouched.
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {}

// No-op: the int key/value device sort does nothing in this configuration.
void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {}

// No-op: the float key / int value device sort does nothing in this configuration.
void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {}

#endif // !WP_ENABLE_CUDA


Expand All @@ -92,3 +170,10 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n)
reinterpret_cast<int *>(keys),
reinterpret_cast<int *>(values), n);
}

// Exported host entry point for sorting float keys with int values.
// keys and values carry addresses as 64-bit integers; they are reinterpreted as
// float* and int* and forwarded, together with n, to radix_sort_pairs_host().
void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n)
{
    float* const key_data = reinterpret_cast<float *>(keys);
    int* const value_data = reinterpret_cast<int *>(values);

    radix_sort_pairs_host(key_data, value_data, n);
}
34 changes: 34 additions & 0 deletions warp/native/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,37 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
reinterpret_cast<int *>(keys),
reinterpret_cast<int *>(values), n);
}

void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
{
ContextGuard guard(context);

cub::DoubleBuffer<float> d_keys(keys, keys + n);
cub::DoubleBuffer<int> d_values(values, values + n);

RadixSortTemp temp;
radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);

// sort
check_cuda(cub::DeviceRadixSort::SortPairs(
temp.mem,
temp.size,
d_keys,
d_values,
n, 0, 32,
(cudaStream_t)cuda_stream_get_current()));

if (d_keys.Current() != keys)
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);

if (d_values.Current() != values)
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
}

// Exported device entry point for sorting float keys with int values.
// keys and values carry addresses as 64-bit integers; they are reinterpreted as
// float* and int* and forwarded to radix_sort_pairs_device() with WP_CURRENT_CONTEXT.
void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
{
    float* const key_data = reinterpret_cast<float *>(keys);
    int* const value_data = reinterpret_cast<int *>(values);

    radix_sort_pairs_device(WP_CURRENT_CONTEXT, key_data, value_data, n);
}
4 changes: 3 additions & 1 deletion warp/native/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@

// Radix sort declarations. The pair sorts take a key array, an int value array
// and a count n; they are overloaded on the key type (int or float).

// Both output parameters are optional and default to NULL.
void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
// Host sort, int keys.
void radix_sort_pairs_host(int* keys, int* values, int n);
// NOTE(review): the same signature is declared again two lines below; the repeat
// is legal but redundant - it looks like residue from the diff view. Confirm and drop one.
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
// Host sort, float keys.
void radix_sort_pairs_host(float* keys, int* values, int n);
// Device sort, int keys.
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
// Device sort, float keys.
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
3 changes: 3 additions & 0 deletions warp/native/warp.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ extern "C"
// Exported key/value radix sort entry points. keys and values are addresses
// passed as 64-bit integers; n is the number of pairs to sort.

// int keys, int values
WP_API void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n);
WP_API void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n);

// float keys, int values
WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n);
WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n);

WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);

Expand Down
56 changes: 34 additions & 22 deletions warp/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,37 +79,49 @@ def test_array_scan_error_unsupported_dtype(test, device):


def test_radix_sort_pairs(test, device):
    """Sort eight key/value pairs once per supported key dtype and check the result.

    The arrays hold 16 elements because radix_sort_pairs requires storage for
    2*count elements; only the first 8 entries are compared after the sort.
    """
    keyTypes = [int, wp.float32]

    for keyType in keyTypes:
        keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
        values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
        wp.utils.radix_sort_pairs(keys, values, 8)
        # keys come back ascending, and each value stays attached to its key
        assert_np_equal(keys.numpy()[:8], np.array((1, 2, 3, 4, 5, 6, 7, 8)))
        assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3)))


def test_radix_sort_pairs_empty(test, device):
    """Call radix_sort_pairs on empty arrays with count 0 for each supported key dtype.

    There are no assertions; the test passes if the call completes without raising.
    """
    keyTypes = [int, wp.float32]

    for keyType in keyTypes:
        keys = wp.array((), dtype=keyType, device=device)
        values = wp.array((), dtype=int, device=device)
        wp.utils.radix_sort_pairs(keys, values, 0)


def test_radix_sort_pairs_error_insufficient_storage(test, device):
    """Check that sorting 3 pairs held in 3-element arrays raises a RuntimeError.

    The expected message states that storage for 2*count elements is required.
    The check is repeated for each supported key dtype.
    """
    keyTypes = [int, wp.float32]

    for keyType in keyTypes:
        keys = wp.array((1, 2, 3), dtype=keyType, device=device)
        values = wp.array((1, 2, 3), dtype=int, device=device)
        with test.assertRaisesRegex(
            RuntimeError,
            r"Array storage must be large enough to contain 2\*count elements$",
        ):
            wp.utils.radix_sort_pairs(keys, values, 3)


def test_radix_sort_pairs_error_unsupported_dtype(test, device):
    """Check that float values are rejected with "Unsupported data type".

    The values array always uses dtype float, so the call is expected to raise
    for every key dtype in the list.
    """
    keyTypes = [int, wp.float32]

    for keyType in keyTypes:
        keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
        values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
        with test.assertRaisesRegex(
            RuntimeError,
            r"Unsupported data type$",
        ):
            wp.utils.radix_sort_pairs(keys, values, 1)


def test_array_sum(test, device):
Expand Down
4 changes: 4 additions & 0 deletions warp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,15 @@ def radix_sort_pairs(keys, values, count: int):
if keys.device.is_cpu:
if keys.dtype == wp.int32 and values.dtype == wp.int32:
runtime.core.radix_sort_pairs_int_host(keys.ptr, values.ptr, count)
elif keys.dtype == wp.float32 and values.dtype == wp.int32:
runtime.core.radix_sort_pairs_float_host(keys.ptr, values.ptr, count)
else:
raise RuntimeError("Unsupported data type")
elif keys.device.is_cuda:
if keys.dtype == wp.int32 and values.dtype == wp.int32:
runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
elif keys.dtype == wp.float32 and values.dtype == wp.int32:
runtime.core.radix_sort_pairs_float_device(keys.ptr, values.ptr, count)
else:
raise RuntimeError("Unsupported data type")

Expand Down

0 comments on commit 29d6783

Please sign in to comment.