Skip to content

Commit

Permalink
Refactor stream mode setup for gtests (#15337)
Browse files Browse the repository at this point in the history
Setting up the stream mode logic was duplicated in `testing_main.hpp` and `error_handing_test.cu`.
Refactoring the logic will help setup for a large strings test fixture in a follow-on PR.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - https://github.com/nvdbaranec
  - Mark Harris (https://github.com/harrism)

URL: #15337
  • Loading branch information
davidwendt authored Apr 1, 2024
1 parent 0a8807e commit e5f9e2d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
57 changes: 36 additions & 21 deletions cpp/include/cudf_test/testing_main.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,32 @@ inline auto parse_cudf_test_opts(int argc, char** argv)
}
}

/**
* @brief Sets up stream mode memory resource adaptor
*
* The resource adaptor is only set as the current device resource if the
* stream mode is enabled.
*
* The caller must keep the return object alive for the life of the test runs.
*
* @param cmd_opts Command line options returned by parse_cudf_test_opts
* @return Memory resource adaptor
*/
inline auto make_stream_mode_adaptor(cxxopts::ParseResult const& cmd_opts)
{
auto resource = rmm::mr::get_current_device_resource();
auto const stream_mode = cmd_opts["stream_mode"].as<std::string>();
auto const stream_error_mode = cmd_opts["stream_error_mode"].as<std::string>();
auto const error_on_invalid_stream = (stream_error_mode == "error");
auto const check_default_stream = (stream_mode == "new_cudf_default");
auto adaptor =
make_stream_checking_resource_adaptor(resource, error_on_invalid_stream, check_default_stream);
if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) {
rmm::mr::set_current_device_resource(&adaptor);
}
return adaptor;
}

/**
* @brief Macro that defines main function for gtest programs that use rmm
*
Expand All @@ -155,25 +181,14 @@ inline auto parse_cudf_test_opts(int argc, char** argv)
* function parses the command line to customize test behavior, like the
* allocation mode used for creating the default memory resource.
*/
#define CUDF_TEST_PROGRAM_MAIN() \
int main(int argc, char** argv) \
{ \
::testing::InitGoogleTest(&argc, argv); \
auto const cmd_opts = parse_cudf_test_opts(argc, argv); \
auto const rmm_mode = cmd_opts["rmm_mode"].as<std::string>(); \
auto resource = cudf::test::create_memory_resource(rmm_mode); \
rmm::mr::set_current_device_resource(resource.get()); \
\
auto const stream_mode = cmd_opts["stream_mode"].as<std::string>(); \
if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) { \
auto const stream_error_mode = cmd_opts["stream_error_mode"].as<std::string>(); \
auto const error_on_invalid_stream = (stream_error_mode == "error"); \
auto const check_default_stream = (stream_mode == "new_cudf_default"); \
auto adaptor = make_stream_checking_resource_adaptor( \
resource.get(), error_on_invalid_stream, check_default_stream); \
rmm::mr::set_current_device_resource(&adaptor); \
return RUN_ALL_TESTS(); \
} \
\
return RUN_ALL_TESTS(); \
#define CUDF_TEST_PROGRAM_MAIN() \
int main(int argc, char** argv) \
{ \
::testing::InitGoogleTest(&argc, argv); \
auto const cmd_opts = parse_cudf_test_opts(argc, argv); \
auto const rmm_mode = cmd_opts["rmm_mode"].as<std::string>(); \
auto resource = cudf::test::create_memory_resource(rmm_mode); \
rmm::mr::set_current_device_resource(resource.get()); \
auto adaptor = make_stream_mode_adaptor(cmd_opts); \
return RUN_ALL_TESTS(); \
}
14 changes: 2 additions & 12 deletions cpp/tests/error/error_handling_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,7 @@ TEST(DebugAssert, cudf_assert_true)
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
auto const cmd_opts = parse_cudf_test_opts(argc, argv);
auto const stream_mode = cmd_opts["stream_mode"].as<std::string>();
if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) {
auto resource = rmm::mr::get_current_device_resource();
auto const stream_error_mode = cmd_opts["stream_error_mode"].as<std::string>();
auto const error_on_invalid_stream = (stream_error_mode == "error");
auto const check_default_stream = (stream_mode == "new_cudf_default");
auto adaptor = make_stream_checking_resource_adaptor(
resource, error_on_invalid_stream, check_default_stream);
rmm::mr::set_current_device_resource(&adaptor);
return RUN_ALL_TESTS();
}
auto const cmd_opts = parse_cudf_test_opts(argc, argv);
auto adaptor = make_stream_mode_adaptor(cmd_opts);
return RUN_ALL_TESTS();
}

0 comments on commit e5f9e2d

Please sign in to comment.