Skip to content

Commit

Permalink
#1934: Add additional test cases for model switching
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Dec 6, 2022
1 parent 93fcbff commit eee6783
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions tests/unit/collection/test_lb_data_retention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,98 @@ TEST_F(TestLBDataRetention, test_lbdata_config_retention_higher) {
}
}

TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_1) {
static constexpr int const first_stage_num_phases = 11;

// We must have more or equal number of elements than nodes for this test to
// work properly
EXPECT_GE(num_elms, vt::theContext()->getNumNodes());
auto range = vt::Index1D(num_elms);
vt::vrt::collection::CollectionProxy<TestCol> proxy;

// Construct two collections
runInEpochCollective([&]{
proxy = vt::theCollection()->constructCollective<TestCol>(
range, "test_lbdata_config_retention_higher"
);
});

// Get the base model, assert it's valid
auto base = theLBManager()->getBaseLoadModel();
EXPECT_NE(base, nullptr);

// Create a new models
auto model_10_phases = std::make_shared<PersistenceMedianLastN>(base, 10U);
auto model_1_phase = std::make_shared<PersistenceMedianLastN>(base, 1U);

// Set model which needs 10 phases of data
theLBManager()->setLoadModel(model_10_phases);

for (uint32_t i=0; i<first_stage_num_phases; ++i) {
runInEpochCollective([&]{
// Do some work.
proxy.broadcastCollective<MyMsg<TestCol>, TestCol::colHandler>();
});
// Go to the next phase.
vt::thePhase()->nextPhaseCollective();
}

// Set model which needs only 1 phase of data
theLBManager()->setLoadModel(model_1_phase);

runInEpochCollective([&]{
// Do some work.
proxy.broadcastCollective<MyMsg<TestCol>, TestCol::colHandler>();
});
// Go to the next phase.
vt::thePhase()->nextPhaseCollective();
}

TEST_F(TestLBDataRetention, test_lbdata_retention_model_switch_2) {
static constexpr int const first_stage_num_phases = 6;

// We must have more or equal number of elements than nodes for this test to
// work properly
EXPECT_GE(num_elms, vt::theContext()->getNumNodes());
auto range = vt::Index1D(num_elms);
vt::vrt::collection::CollectionProxy<TestCol> proxy;

// Construct two collections
runInEpochCollective([&]{
proxy = vt::theCollection()->constructCollective<TestCol>(
range, "test_lbdata_config_retention_higher"
);
});

// Get the base model, assert it's valid
auto base = theLBManager()->getBaseLoadModel();
EXPECT_NE(base, nullptr);

// Create a new models
auto model_10_phases = std::make_shared<PersistenceMedianLastN>(base, 10U);
auto model_1_phase = std::make_shared<PersistenceMedianLastN>(base, 1U);

// Set model which needs 10 phases of data
theLBManager()->setLoadModel(model_10_phases);

for (uint32_t i=0; i<first_stage_num_phases; ++i) {
runInEpochCollective([&]{
// Do some work.
proxy.broadcastCollective<MyMsg<TestCol>, TestCol::colHandler>();
});
// Go to the next phase.
vt::thePhase()->nextPhaseCollective();
}

// Set model which needs only 1 phase of data
theLBManager()->setLoadModel(model_1_phase);

runInEpochCollective([&]{
// Do some work.
proxy.broadcastCollective<MyMsg<TestCol>, TestCol::colHandler>();
});
// Go to the next phase.
vt::thePhase()->nextPhaseCollective();
}

}}} // end namespace vt::tests::unit

0 comments on commit eee6783

Please sign in to comment.