diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9cbacee8e8d..890b01e99a8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -463,6 +463,7 @@ add_library( src/hash/sha256_hash.cu src/hash/sha384_hash.cu src/hash/sha512_hash.cu + src/hash/xxhash_32.cu src/hash/xxhash_64.cu src/interop/dlpack.cpp src/interop/arrow_utilities.cpp diff --git a/cpp/include/cudf/hashing.hpp b/cpp/include/cudf/hashing.hpp index 307a52cd242..279b841c23a 100644 --- a/cpp/include/cudf/hashing.hpp +++ b/cpp/include/cudf/hashing.hpp @@ -166,6 +166,26 @@ std::unique_ptr sha512( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); +/** + * @brief Computes the XXHash_32 hash value of each row in the given table + * + * This function computes the hash of each column using the `seed` for the first column + * and the resulting hash as a seed for the next column and so on. + * The result is a uint32 value for each row. + * + * @param input The table of columns to hash + * @param seed Optional seed value to use for the hash function + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * + * @returns A column where each row is the hash of a row from the input + */ +std::unique_ptr xxhash_32( + table_view const& input, + uint32_t seed = DEFAULT_HASH_SEED, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + /** * @brief Computes the XXHash_64 hash value of each row in the given table * diff --git a/cpp/include/cudf/hashing/detail/hashing.hpp b/cpp/include/cudf/hashing/detail/hashing.hpp index 7cb80081a95..39bfeb633a3 100644 --- a/cpp/include/cudf/hashing/detail/hashing.hpp +++ b/cpp/include/cudf/hashing/detail/hashing.hpp @@ -61,6 +61,11 @@ std::unique_ptr sha512(table_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); +std::unique_ptr xxhash_32(table_view const& input, + uint64_t seed, + rmm::cuda_stream_view, + rmm::device_async_resource_ref mr); + std::unique_ptr xxhash_64(table_view const& input, uint64_t seed, rmm::cuda_stream_view, diff --git a/cpp/include/cudf/hashing/detail/xxhash_32.cuh b/cpp/include/cudf/hashing/detail/xxhash_32.cuh new file mode 100644 index 00000000000..4cc607fc186 --- /dev/null +++ b/cpp/include/cudf/hashing/detail/xxhash_32.cuh @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cudf::hashing::detail { + +template +struct XXHash_32 { + using result_type = std::uint32_t; + + CUDF_HOST_DEVICE constexpr XXHash_32(uint32_t seed = cudf::DEFAULT_HASH_SEED) : _impl{seed} {} + + __device__ constexpr result_type operator()(Key const& key) const { return this->_impl(key); } + + __device__ constexpr result_type compute_bytes(cuda::std::byte const* bytes, + std::uint64_t size) const + { + return this->_impl.compute_hash(bytes, size); + } + + private: + template + __device__ constexpr result_type compute(T const& key) const + { + return this->compute_bytes(reinterpret_cast(&key), sizeof(T)); + } + + cuco::xxhash_32 _impl; +}; + +template <> +XXHash_32::result_type __device__ inline XXHash_32::operator()(bool const& key) const +{ + return this->compute(static_cast(key)); +} + +template <> +XXHash_32::result_type __device__ inline XXHash_32::operator()(float const& key) const +{ + return this->compute(normalize_nans_and_zeros(key)); +} + +template <> +XXHash_32::result_type __device__ inline XXHash_32::operator()( + double const& key) const +{ + return this->compute(normalize_nans_and_zeros(key)); +} + +template <> +XXHash_32::result_type + __device__ inline XXHash_32::operator()(cudf::string_view const& key) const +{ + return this->compute_bytes(reinterpret_cast(key.data()), + key.size_bytes()); +} + +template <> +XXHash_32::result_type + __device__ inline XXHash_32::operator()(numeric::decimal32 const& key) const +{ + return this->compute(key.value()); +} + +template <> +XXHash_32::result_type + __device__ inline XXHash_32::operator()(numeric::decimal64 const& key) const +{ + return this->compute(key.value()); +} + +template <> +XXHash_32::result_type + __device__ inline XXHash_32::operator()(numeric::decimal128 const& key) const +{ + return this->compute(key.value()); +} + +template <> +hash_value_type __device__ inline XXHash_32::operator()( + cudf::list_view const& key) const +{ + CUDF_UNREACHABLE("List column hashing is not supported"); +} + +template <> +hash_value_type __device__ inline XXHash_32::operator()( + cudf::struct_view const& key) const +{ + CUDF_UNREACHABLE("Direct hashing of struct_view is not supported"); +} + +} // namespace cudf::hashing::detail diff --git a/cpp/src/hash/xxhash_32.cu b/cpp/src/hash/xxhash_32.cu new file mode 100644 index 00000000000..7a864fbc98e --- /dev/null +++ b/cpp/src/hash/xxhash_32.cu @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace cudf { +namespace hashing { +namespace detail { + +namespace { + +using hash_value_type = uint32_t; + +/** + * @brief Computes the hash value of a row in the given table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ +template +class device_row_hasher { + public: + device_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) + : _check_nulls(nulls), _table(t), _seed(seed) + { + } + + __device__ auto operator()(size_type row_index) const noexcept + { + return cudf::detail::accumulate( + _table.begin(), + _table.end(), + _seed, + [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { + return cudf::type_dispatcher( + column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); + }); + } + + /** + * @brief Computes the hash value of an element in the given column. + */ + class element_hasher_adapter { + public: + template ())> + __device__ hash_value_type operator()(column_device_view const& col, + size_type const row_index, + Nullate const _check_nulls, + hash_value_type const _seed) const noexcept + { + if (_check_nulls && col.is_null(row_index)) { + return cuda::std::numeric_limits::max(); + } + auto const hasher = XXHash_32{_seed}; + return hasher(col.element(row_index)); + } + + template ())> + __device__ hash_value_type operator()(column_device_view const&, + size_type const, + Nullate const, + hash_value_type const) const noexcept + { + CUDF_UNREACHABLE("Unsupported type for XXHash_32"); + } + }; + + Nullate const _check_nulls; + table_device_view const _table; + hash_value_type const _seed; +}; + +} // namespace + +std::unique_ptr xxhash_32(table_view const& input, + uint32_t seed, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto output = make_numeric_column(data_type(type_to_id()), + input.num_rows(), + mask_state::UNALLOCATED, + stream, + mr); + + // Return early if there's nothing to hash + if (input.num_columns() == 0 || input.num_rows() == 0) { return output; } + + bool const nullable = has_nulls(input); + auto const input_view = table_device_view::create(input, stream); + auto output_view = output->mutable_view(); + + // Compute the hash value for each row + thrust::tabulate(rmm::exec_policy(stream), + output_view.begin(), + output_view.end(), + device_row_hasher(nullable, *input_view, seed)); + + return output; +} + +} // namespace detail + +std::unique_ptr xxhash_32(table_view const& input, + uint32_t seed, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::xxhash_32(input, seed, stream, mr); +} + +} // namespace hashing +} // namespace cudf diff --git a/cpp/src/io/orc/dict_enc.cu b/cpp/src/io/orc/dict_enc.cu index 0cb5c382631..6dc3a9d793c 100644 --- a/cpp/src/io/orc/dict_enc.cu +++ b/cpp/src/io/orc/dict_enc.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include diff --git a/cpp/src/io/parquet/chunk_dict.cu b/cpp/src/io/parquet/chunk_dict.cu index b85ebf2fa1a..5ca94a0bec7 100644 --- a/cpp/src/io/parquet/chunk_dict.cu +++ b/cpp/src/io/parquet/chunk_dict.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include diff --git a/cpp/src/join/join_common_utils.cuh b/cpp/src/join/join_common_utils.cuh index 4f75908fe72..2d67f3391d9 100644 --- a/cpp/src/join/join_common_utils.cuh +++ b/cpp/src/join/join_common_utils.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index adf512811cc..6984b10eefd 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -190,6 +190,7 @@ ConfigureTest( hashing/sha256_test.cpp hashing/sha384_test.cpp hashing/sha512_test.cpp + hashing/xxhash_32_test.cpp hashing/xxhash_64_test.cpp ) diff --git a/cpp/tests/hashing/xxhash_32_test.cpp b/cpp/tests/hashing/xxhash_32_test.cpp new file mode 100644 index 00000000000..18ba7b1d0be --- /dev/null +++ b/cpp/tests/hashing/xxhash_32_test.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +class XXHash_32_Test : public cudf::test::BaseFixture {}; + +TEST_F(XXHash_32_Test, TestInteger) +{ + auto col1 = cudf::test::fixed_width_column_wrapper{{0, 42, 825}}; + auto constexpr seed = 0u; + auto const output = cudf::hashing::xxhash_32(cudf::table_view({col1}), seed); + + // Expected results were generated with the reference implementation: + // https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h + auto expected = + cudf::test::fixed_width_column_wrapper({148298089u, 1161967057u, 1066694813u}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), expected); +} + +TEST_F(XXHash_32_Test, TestDouble) +{ + auto col1 = cudf::test::fixed_width_column_wrapper{{-8., 25., 90.}}; + auto constexpr seed = 42u; + + auto const output = cudf::hashing::xxhash_32(cudf::table_view({col1}), seed); + + // Expected results were generated with the reference implementation: + // https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h + auto expected = + cudf::test::fixed_width_column_wrapper({2276435783u, 3120212431u, 3454197470u}); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), expected); +} + +TEST_F(XXHash_32_Test, StringType) +{ + auto col1 = cudf::test::strings_column_wrapper({"I", "am", "AI"}); + auto constexpr seed = 825u; + + auto output = cudf::hashing::xxhash_32(cudf::table_view({col1}), seed); + + // Expected results were generated with the reference implementation: + // https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h + auto expected = + cudf::test::fixed_width_column_wrapper({320624298u, 1612654309u, 1409499009u}); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), expected); +} diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 72bb85821fa..a288eb245e0 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -2838,16 +2838,22 @@ def hash_values( Parameters ---------- - method : {'murmur3', 'md5', 'xxhash64'}, default 'murmur3' + method : {'murmur3', 'xxhash32', 'xxhash64', 'md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512'}, default 'murmur3' Hash function to use: * murmur3: MurmurHash3 hash function - * md5: MD5 hash function + * xxhash32: xxHash32 hash function * xxhash64: xxHash64 hash function + * md5: MD5 hash function + * sha1: SHA-1 hash function + * sha224: SHA-224 hash function + * sha256: SHA-256 hash function + * sha384: SHA-384 hash function + * sha512: SHA-512 hash function seed : int, optional Seed value to use for the hash function. This parameter is only - supported for 'murmur3' and 'xxhash64'. + supported for 'murmur3', 'xxhash32', and 'xxhash64'. Returns @@ -2902,7 +2908,7 @@ def hash_values( 2 fe061786ea286a515b772d91b0dfcd70 dtype: object """ - seed_hash_methods = {"murmur3", "xxhash64"} + seed_hash_methods = {"murmur3", "xxhash32", "xxhash64"} if seed is None: seed = 0 elif method not in seed_hash_methods: @@ -2916,6 +2922,8 @@ def hash_values( ) if method == "murmur3": plc_column = plc.hashing.murmurhash3_x86_32(plc_table, seed) + elif method == "xxhash32": + plc_column = plc.hashing.xxhash_32(plc_table, seed) elif method == "xxhash64": plc_column = plc.hashing.xxhash_64(plc_table, seed) elif method == "md5": diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index d04fd97dcbd..51de33576c0 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1440,6 +1440,7 @@ def test_assign_callable(mapping): "sha256", "sha384", "sha512", + "xxhash32", "xxhash64", ], ) @@ -1447,6 +1448,7 @@ def test_assign_callable(mapping): def test_dataframe_hash_values(nrows, method, seed): warning_expected = seed is not None and method not in { "murmur3", + "xxhash32", "xxhash64", } potential_warning = ( @@ -1472,6 +1474,7 @@ def test_dataframe_hash_values(nrows, method, seed): "sha256": object, "sha384": object, "sha512": object, + "xxhash32": np.uint32, "xxhash64": np.uint64, } assert out.dtype == expected_dtypes[method] @@ -1486,7 +1489,7 @@ def test_dataframe_hash_values(nrows, method, seed): assert_eq(gdf["a"].hash_values(method=method, seed=seed), out_one) -@pytest.mark.parametrize("method", ["murmur3", "xxhash64"]) +@pytest.mark.parametrize("method", ["murmur3", "xxhash32", "xxhash64"]) def test_dataframe_hash_values_seed(method): gdf = cudf.DataFrame() data = np.arange(10) @@ -1500,6 +1503,34 @@ def test_dataframe_hash_values_seed(method): assert_neq(out_one, out_two) +def test_dataframe_hash_values_xxhash32(): + # xxhash32 has no built-in implementation in Python and we don't want to + # add a testing dependency, so we use regression tests against known good + # values. + gdf = cudf.DataFrame({"a": [0.0, 1.0, 2.0, np.inf, np.nan]}) + gdf["b"] = -gdf["a"] + out_a = gdf["a"].hash_values(method="xxhash32", seed=0) + expected_a = cudf.Series( + [3736311059, 2307980487, 2906647130, 746578903, 4294967295], + dtype=np.uint32, + ) + assert_eq(out_a, expected_a) + + out_b = gdf["b"].hash_values(method="xxhash32", seed=42) + expected_b = cudf.Series( + [1076387279, 2261349915, 531498073, 650869264, 4294967295], + dtype=np.uint32, + ) + assert_eq(out_b, expected_b) + + out_df = gdf.hash_values(method="xxhash32", seed=0) + expected_df = cudf.Series( + [1223721700, 2885793241, 1920811472, 1146715602, 4294967295], + dtype=np.uint32, + ) + assert_eq(out_df, expected_df) + + def test_dataframe_hash_values_xxhash64(): # xxhash64 has no built-in implementation in Python and we don't want to # add a testing dependency, so we use regression tests against known good diff --git a/python/pylibcudf/pylibcudf/hashing.pxd b/python/pylibcudf/pylibcudf/hashing.pxd index 2d070ddda69..4421d540af6 100644 --- a/python/pylibcudf/pylibcudf/hashing.pxd +++ b/python/pylibcudf/pylibcudf/hashing.pxd @@ -16,6 +16,10 @@ cpdef Table murmurhash3_x64_128( uint64_t seed=* ) +cpdef Column xxhash_32( + Table input, + uint32_t seed=* +) cpdef Column xxhash_64( Table input, diff --git a/python/pylibcudf/pylibcudf/hashing.pyi b/python/pylibcudf/pylibcudf/hashing.pyi index a849f5d0729..d535d842a18 100644 --- a/python/pylibcudf/pylibcudf/hashing.pyi +++ b/python/pylibcudf/pylibcudf/hashing.pyi @@ -9,6 +9,7 @@ LIBCUDF_DEFAULT_HASH_SEED: Final[int] def murmurhash3_x86_32(input: Table, seed: int = ...) -> Column: ... def murmurhash3_x64_128(input: Table, seed: int = ...) -> Table: ... +def xxhash_32(input: Table, seed: int = ...) -> Column: ... def xxhash_64(input: Table, seed: int = ...) -> Column: ... def md5(input: Table) -> Column: ... def sha1(input: Table) -> Column: ... diff --git a/python/pylibcudf/pylibcudf/hashing.pyx b/python/pylibcudf/pylibcudf/hashing.pyx index 548cffc0ce8..4cb977b506a 100644 --- a/python/pylibcudf/pylibcudf/hashing.pyx +++ b/python/pylibcudf/pylibcudf/hashing.pyx @@ -13,6 +13,7 @@ from pylibcudf.libcudf.hash cimport ( sha256 as cpp_sha256, sha384 as cpp_sha384, sha512 as cpp_sha512, + xxhash_32 as cpp_xxhash_32, xxhash_64 as cpp_xxhash_64, ) from pylibcudf.libcudf.table.table cimport table @@ -30,6 +31,7 @@ __all__ = [ "sha256", "sha384", "sha512", + "xxhash_32", "xxhash_64", ] @@ -95,6 +97,37 @@ cpdef Table murmurhash3_x64_128( return Table.from_libcudf(move(c_result)) +cpdef Column xxhash_32( + Table input, + uint32_t seed=DEFAULT_HASH_SEED +): + """Computes the xxHash 32-bit hash value of each row in the given table. + + For details, see :cpp:func:`xxhash_32`. + + Parameters + ---------- + input : Table + The table of columns to hash + seed : uint32_t + Optional seed value to use for the hash function + + Returns + ------- + pylibcudf.Column + A column where each row is the hash of a row from the input + """ + + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_xxhash_32( + input.view(), + seed + ) + + return Column.from_libcudf(move(c_result)) + + cpdef Column xxhash_64( Table input, uint64_t seed=DEFAULT_HASH_SEED diff --git a/python/pylibcudf/pylibcudf/libcudf/hash.pxd b/python/pylibcudf/pylibcudf/libcudf/hash.pxd index 4e8a01b41a5..b1640c4ad67 100644 --- a/python/pylibcudf/pylibcudf/libcudf/hash.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/hash.pxd @@ -44,6 +44,11 @@ cdef extern from "cudf/hashing.hpp" namespace "cudf::hashing" nogil: const table_view& input ) except +libcudf_exception_handler + cdef unique_ptr[column] xxhash_32( + const table_view& input, + const uint32_t seed + ) except +libcudf_exception_handler + cdef unique_ptr[column] xxhash_64( const table_view& input, const uint64_t seed diff --git a/python/pylibcudf/pylibcudf/tests/test_hashing.py b/python/pylibcudf/pylibcudf/tests/test_hashing.py index 83fb50fa4ef..87d21618a75 100644 --- a/python/pylibcudf/pylibcudf/tests/test_hashing.py +++ b/python/pylibcudf/pylibcudf/tests/test_hashing.py @@ -34,7 +34,9 @@ def hash_single_uint32(val, seed=0): def hash_combine_32(lhs, rhs): - return np.uint32(lhs ^ (rhs + 0x9E3779B9 + (lhs << 6) + (lhs >> 2))) + return np.uint32( + int((lhs ^ (rhs + 0x9E3779B9 + (lhs << 6) + (lhs >> 2)))) % 2**32 + ) def uint_hash_combine_32(lhs, rhs): @@ -80,22 +82,6 @@ def list_struct_table(): return data -def python_hash_value(x, method): - if method == "murmurhash3_x86_32": - return libcudf_mmh3_x86_32(x) - elif method == "murmurhash3_x64_128": - hasher = mmh3.mmh3_x64_128(seed=plc.hashing.LIBCUDF_DEFAULT_HASH_SEED) - hasher.update(x) - # libcudf returns a tuple of two 64-bit integers - return hasher.utupledigest() - elif method == "xxhash_64": - return xxhash.xxh64( - x, seed=plc.hashing.LIBCUDF_DEFAULT_HASH_SEED - ).intdigest() - else: - return getattr(hashlib, method)(x).hexdigest() - - @pytest.mark.parametrize( "method", ["sha1", "sha224", "sha256", "sha384", "sha512", "md5"] ) @@ -115,6 +101,23 @@ def py_hasher(val): assert_column_eq(got, expect) +def test_hash_column_xxhash32(pa_scalar_input_column, plc_scalar_input_tbl): + def py_hasher(val): + return xxhash.xxh32( + scalar_to_binary(val), seed=plc.hashing.LIBCUDF_DEFAULT_HASH_SEED + ).intdigest() + + expect = pa.array( + [py_hasher(val) for val in pa_scalar_input_column.to_pylist()], + type=pa.uint32(), + ) + got = plc.hashing.xxhash_32( + plc_scalar_input_tbl, plc.hashing.LIBCUDF_DEFAULT_HASH_SEED + ) + + assert_column_eq(got, expect) + + def test_hash_column_xxhash64(pa_scalar_input_column, plc_scalar_input_tbl): def py_hasher(val): return xxhash.xxh64( @@ -125,7 +128,9 @@ def py_hasher(val): [py_hasher(val) for val in pa_scalar_input_column.to_pylist()], type=pa.uint64(), ) - got = plc.hashing.xxhash_64(plc_scalar_input_tbl, 0) + got = plc.hashing.xxhash_64( + plc_scalar_input_tbl, plc.hashing.LIBCUDF_DEFAULT_HASH_SEED + ) assert_column_eq(got, expect)