Skip to content

Commit

Permalink
try to resolve & locally reproduce the fail
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Dec 4, 2024
1 parent 240b527 commit 2e4a031
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions cpp/test/sparse/convert_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
}

std::random_device rd;
std::mt19937 gen(rd());
std::mt19937 gen(random_number = rd());
std::uniform_int_distribution<index_t> dis(0, total - 1);

while (num_ones > 0) {
Expand Down Expand Up @@ -318,8 +318,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
size_t start_idx = row_ptrs1[i];
size_t end_idx = row_ptrs1[i + 1];

std::vector<int> cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx);
std::vector<int> cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx);
std::vector<index_t> cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx);
std::vector<index_t> cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx);

std::sort(cols1.begin(), cols1.end());
std::sort(cols2.begin(), cols2.end());
Expand Down Expand Up @@ -396,9 +396,13 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_

resource::sync_stream(handle);

ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h));
ASSERT_TRUE(raft::devArrMatch<value_t>(
values_expected_d.data(), values_d.data(), nnz, raft::Compare<value_t>(), stream));
EXPECT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h))
<< " n_row: " << params.n_rows << ", n_cols: " << params.n_cols << ", nnz: " << nnz
<< ", random_number: " << random_number;
EXPECT_TRUE(raft::devArrMatch<value_t>(
values_expected_d.data(), values_d.data(), nnz, raft::Compare<value_t>(), stream))
<< " n_row: " << params.n_rows << ", n_cols: " << params.n_cols << ", nnz: " << nnz
<< ", random_number: " << random_number;
}

protected:
Expand All @@ -418,6 +422,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
rmm::device_uvector<index_t> indptr_expected_d;
rmm::device_uvector<index_t> indices_expected_d;
rmm::device_uvector<float> values_expected_d;

unsigned int random_number;
};

using BitmapToCSRTestI = BitmapToCSRTest<uint32_t, int, float>;
Expand Down

0 comments on commit 2e4a031

Please sign in to comment.