diff --git a/include/openrand/base_state.h b/include/openrand/base_state.h index edd646d..d886813 100644 --- a/include/openrand/base_state.h +++ b/include/openrand/base_state.h @@ -100,7 +100,7 @@ class BaseRNG { /** * @brief Generates a number from a uniform distribution between a and b. * - * @Note For integer type, this method is slightly biased towards lower numbers. + * @Note For integer types, consider using @ref range for greater control. * * @tparam T Data type to be returned. Can be float or double. * double. @@ -114,13 +114,12 @@ class BaseRNG { static_assert(!(std::is_integral_v && sizeof(T) > sizeof(int32_t)), "64 bit int not yet supported"); - T range = high - low; + T r = high - low; if constexpr (std::is_floating_point_v) { - return low + range * rand(); + return low + r * rand(); } else if constexpr (std::is_integral_v) { - // Thanks to (https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/) - return low + T(((uint64_t)rand() * (uint64_t)range) >> 32); + return low + range(r); } } @@ -190,6 +189,53 @@ class BaseRNG { return mean + randn() * std_dev; } + /** + * @brief Generates a random integer of certain range + * + * This uses the method described in [1] to generate a random integer + * of range [0..N) + * + * + * @attention if using non-biased version, please make sure that N is not + * too large [2] + * + * @tparam biased if true, the faster, but slightly biased variant is used. + * @tparam T integer type (<=32 bit) to be returned + * + * @param N A random integral of range [0..N) will be returned + * @return T random number from a uniform distribution between 0 and N + * + * + * @note [1] https://lemire.me/blog/2016/06/30/fast-random-shuffling/ + * @note [2] If N=2^b (b<32), pr(taking the branch) = p = 1/2^(32-b). For N=2^24, + * this value is 1/2^8 = .4%, quite negligible. But GPU complicates this simple math. + * Assuming a warp size of 32, the probability of a thread taking the branch becomes + * 1 - (1-p)^32. For N=2^24, that value is 11.8%. For N=2^20, it's 0.8%. + */ + template + DEVICE T range(const T N) { + // static_assert(std::is_integral_v && sizeof(T) <= sizeof(int32_t), + // "64 bit int not yet supported"); + + uint32_t x = gen().template draw(); + uint64_t res = static_cast(x) * static_cast(N); + + if constexpr (biased) { + return static_cast(res >> 32); + } else { + uint32_t leftover = static_cast(res); + if (leftover < N) { + uint32_t threshold = -N % N; + while (leftover < threshold) { + x = gen().template draw(); + res = static_cast(x) * static_cast(N); + leftover = static_cast(res); + } + } + return static_cast(res); + } + } + /** * @brief Generates a random number from a gamma distribution with shape alpha * and scale b. diff --git a/include/openrand/philox.h b/include/openrand/philox.h index ee1c7cb..7f0398c 100644 --- a/include/openrand/philox.h +++ b/include/openrand/philox.h @@ -62,8 +62,8 @@ class Philox : public BaseRNG { * @param ctr1 (Optional) Another 32-bit counter exposed for advanced use. */ DEVICE Philox(uint64_t seed, uint32_t ctr, - uint32_t global_seed = openrand::DEFAULT_GLOBAL_SEED, - uint32_t ctr1 = 0x12345) + uint32_t global_seed = openrand::DEFAULT_GLOBAL_SEED, + uint32_t ctr1 = 0x12345) : seed_hi((uint32_t)(seed >> 32)), seed_lo((uint32_t)(seed & 0xFFFFFFFF)), ctr0(ctr), @@ -76,8 +76,7 @@ class Philox : public BaseRNG { generate(); static_assert(std::is_same_v || std::is_same_v); - if constexpr (std::is_same_v) - return _out[0]; + if constexpr (std::is_same_v) return _out[0]; // Not wrapping this block in else{} would lead to compiler warning else { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7b7c89f..541fff5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,9 +21,19 @@ target_link_libraries( GTest::gtest_main ) +add_executable( + base + test_base.cpp +) +target_link_libraries( + base + GTest::gtest_main +) + include(GoogleTest) gtest_discover_tests(uniform) gtest_discover_tests(normal) +gtest_discover_tests(base) # Statistical tests, not run through gtest framework diff --git a/tests/test_base.cpp b/tests/test_base.cpp new file mode 100644 index 0000000..e97c878 --- /dev/null +++ b/tests/test_base.cpp @@ -0,0 +1,62 @@ +// @HEADER +// ******************************************************************************* +// OpenRAND * +// A Performance Portable, Reproducible Random Number Generation Library * +// * +// Copyright (c) 2023, Michigan State University * +// * +// Permission is hereby granted, free of charge, to any person obtaining a copy * +// of this software and associated documentation files (the "Software"), to deal * +// in the Software without restriction, including without limitation the rights * +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * +// copies of the Software, and to permit persons to whom the Software is * +// furnished to do so, subject to the following conditions: * +// * +// The above copyright notice and this permission notice shall be included in * +// all copies or substantial portions of the Software. * +// * +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * +// SOFTWARE. * +//******************************************************************************** +// @HEADER + +#include +#include +#include +#include +#include + +#include + +template +void test_rangev2(int seed) { + RNG rng(seed, 0); + for (int i = 0; i < 10; i++) { + ASSERT_LT(rng.range(10), 10); + } + const int v = (1 << 27); + for (int i = 0; i < 10; i++) { + // had to create tmp variable. Couldn't directly pass to ASSERT_LT for + // some reason + auto x = rng.template range(10); + ASSERT_LT(x, 10); + auto y = rng.template range(v); + ASSERT_LT(y, v); + } + for (int i = 0; i < 10; i++) { + auto z = rng.template range(1000); + ASSERT_LT(z, 1000); + } +} + +TEST(BASE, rangev2) { + test_rangev2(42); + test_rangev2(37); + test_rangev2(12345); + test_rangev2(1234); +} \ No newline at end of file