diff --git a/cajita/src/Cajita_Halo.hpp b/cajita/src/Cajita_Halo.hpp index 1460f6f8f..d351be74a 100644 --- a/cajita/src/Cajita_Halo.hpp +++ b/cajita/src/Cajita_Halo.hpp @@ -943,7 +943,7 @@ struct ArrayPackMemorySpace //---------------------------------------------------------------------------// /*! - \brief Array creation function. + \brief Halo creation function. \param pattern The pattern to build the halo from. \param width Must be less than or equal to the width of the array halo. \param arrays The arrays over which to build the halo. diff --git a/cajita/src/Cajita_SparseHalo.hpp b/cajita/src/Cajita_SparseHalo.hpp index c904af2fc..b927e6802 100644 --- a/cajita/src/Cajita_SparseHalo.hpp +++ b/cajita/src/Cajita_SparseHalo.hpp @@ -43,7 +43,7 @@ namespace Experimental communicator and halo size. The arrays must also reside in the same memory space. These requirements are checked at construction. */ -template class SparseHalo @@ -71,7 +71,7 @@ class SparseHalo }; //! data members in AoSoA structure - using aosoa_member_types = DataMembers; + using aosoa_member_types = DataTypes; //! AoSoA tuple type using tuple_type = Cabana::Tuple; @@ -114,13 +114,11 @@ class SparseHalo \brief constructor \tparam LocalGridType local grid type \param pattern The halo pattern to use for halo communication - \param local_grid_ptr pointer to sparse local grid - \param comm MPI communicator + \param sparse_array Sparse array to communicate */ - template + template SparseHalo( halo_pattern_type pattern, - const std::shared_ptr& local_grid_ptr, - MPI_Comm comm ) + const std::shared_ptr& sparse_array ) : _pattern( pattern ) { // Function to get the local id of the neighbor. @@ -151,9 +149,12 @@ class SparseHalo std::accumulate( soa_byte_array.begin(), soa_byte_array.end(), 0 ), static_cast( sizeof( tuple_type ) ) ); + // Get the local grid the array uses. + auto local_grid = sparse_array->layout().localGrid(); + // linear MPI rank ID of the current working rank _self_rank = - local_grid_ptr->neighborRank( std::array( { 0, 0, 0 } ) ); + local_grid->neighborRank( std::array( { 0, 0, 0 } ) ); // set the linear neighbor rank ID // set up correspondence between sending and receiving buffers @@ -161,7 +162,7 @@ class SparseHalo for ( const auto& n : neighbors ) { // neighbor rank linear ID - int rank = local_grid_ptr->neighborRank( n ); + int rank = local_grid->neighborRank( n ); // if neighbor is valid if ( rank >= 0 ) @@ -176,10 +177,10 @@ class SparseHalo _receive_tags.push_back( neighbor_id( flip_id( n ) ) ); // build communication data for owned entries - buildCommData( Own(), local_grid_ptr, n, _owned_buffers, + buildCommData( Own(), local_grid, n, _owned_buffers, _owned_tile_steering, _owned_tile_spaces ); // build communication data for ghosted entries - buildCommData( Ghost(), local_grid_ptr, n, _ghosted_buffers, + buildCommData( Ghost(), local_grid, n, _ghosted_buffers, _ghosted_tile_steering, _ghosted_tile_spaces ); auto& own_index_space = _owned_tile_spaces.back(); @@ -215,7 +216,7 @@ class SparseHalo \tparam DecompositionTag decomposition tag type \tparam LocalGridType sparse local grid type \param decomposition_tag tag to indicate if it's owned or ghosted halo - \param local_grid_ptr sparse local grid shared pointer + \param local_grid sparse local grid shared pointer \param nid neighbor local id (ijk in pattern) \param buffers buffer to be used to store communicated data \param steering steering to be used to guide communications @@ -223,7 +224,7 @@ class SparseHalo */ template void buildCommData( DecompositionTag decomposition_tag, - const std::shared_ptr& local_grid_ptr, + const std::shared_ptr& local_grid, const std::array& nid, std::vector& buffers, std::vector& steering, @@ -231,9 +232,8 @@ class SparseHalo { // get the halo sparse tile index space sharsed with the neighbor spaces.push_back( - local_grid_ptr - ->template sharedTileIndexSpace( - decomposition_tag, entity_type(), nid ) ); + local_grid->template sharedTileIndexSpace( + decomposition_tag, entity_type(), nid ) ); auto& index_space = spaces.back(); // allocate the buffer to store shared data with given neighbor @@ -251,10 +251,10 @@ class SparseHalo /*! \brief update tile index space according to current partition \tparam LocalGridType sparse local grid type - \param local_grid_ptr sparse local grid pointer + \param local_grid sparse local grid pointer */ template - void updateTileSpace( const std::shared_ptr& local_grid_ptr ) + void updateTileSpace( const std::shared_ptr& local_grid ) { // clear index space array first _owned_tile_spaces.clear(); @@ -267,7 +267,7 @@ class SparseHalo // get neighbor relative id auto& n = _valid_neighbor_ids[i]; // get neighbor linear MPI rank ID - int rank = local_grid_ptr->neighborRank( n ); + int rank = local_grid->neighborRank( n ); // check if neighbor rank is valid // the neighbor id should always be valid (as all should be // well-prepared during construction/initialization) @@ -275,11 +275,11 @@ class SparseHalo { // get shared tile index spcae from local grid _owned_tile_spaces.push_back( - local_grid_ptr + local_grid ->template sharedTileIndexSpace( Own(), entity_type(), n ) ); _ghosted_tile_spaces.push_back( - local_grid_ptr + local_grid ->template sharedTileIndexSpace( Ghost(), entity_type(), n ) ); @@ -1283,6 +1283,34 @@ class SparseHalo // SoA total bytes count std::size_t _soa_total_bytes; }; + +//---------------------------------------------------------------------------// +// Sparse halo creation. +//---------------------------------------------------------------------------// +/*! + \brief SparseHalo creation function. + \param pattern The pattern to build the sparse halo from. + \param array The sparse array over which to build the halo. +*/ +template +auto createSparseHalo( + const Pattern& pattern, + const std::shared_ptr< + SparseArray> + array ) +{ + using array_type = + SparseArray; + using memory_space = typename array_type::memory_space; + static constexpr std::size_t num_space_dim = array_type::num_space_dim; + return std::make_shared< + SparseHalo>( pattern, array ); +} + }; // namespace Experimental }; // end namespace Cajita diff --git a/cajita/unit_test/tstSparseHalo.hpp b/cajita/unit_test/tstSparseHalo.hpp index 5ed0fef45..a85e61d71 100644 --- a/cajita/unit_test/tstSparseHalo.hpp +++ b/cajita/unit_test/tstSparseHalo.hpp @@ -465,8 +465,8 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity ) auto sparse_array = createSparseArray( std::string( "test_sparse_grid" ), *sparse_layout ); - SparseHalo - halo( NodeHaloPattern<3>(), local_grid, MPI_COMM_WORLD ); + auto halo = createSparseHalo( + NodeHaloPattern<3>(), sparse_array ); // sample valid halos on rank 0 and broadcast to other ranks // Kokkos::View tile_view; @@ -517,7 +517,7 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity ) } ); sparse_array->resize( sparse_map.sizeCell() ); - halo.template register_halo( sparse_map ); + halo->template register_halo( sparse_map ); MPI_Barrier( MPI_COMM_WORLD ); } @@ -605,10 +605,10 @@ void haloScatterAndGatherTest( ReduceOp reduce_op, EntityType entity ) // halo scatter and gather /// false means the heighbors' halo counting information is not /// collected - halo.scatter( TEST_EXECSPACE(), reduce_op, *sparse_array, false ); + halo->scatter( TEST_EXECSPACE(), reduce_op, *sparse_array, false ); /// halo counting info already collected in the previous scatter, thus true /// and no need to recount again - halo.gather( TEST_EXECSPACE(), *sparse_array, true ); + halo->gather( TEST_EXECSPACE(), *sparse_array, true ); MPI_Barrier( MPI_COMM_WORLD ); // check results