Skip to content

Commit

Permalink
Merge pull request #662 from streeve/view_neighbors
Browse files Browse the repository at this point in the history
Allow Kokkos::Views for neighbor lists
  • Loading branch information
streeve authored Sep 17, 2024
2 parents 0c37ccb + 507112c commit 142068e
Show file tree
Hide file tree
Showing 7 changed files with 462 additions and 248 deletions.
196 changes: 103 additions & 93 deletions core/src/Cabana_Experimental_NeighborList.hpp

Large diffs are not rendered by default.

165 changes: 73 additions & 92 deletions core/src/Cabana_LinkedCellList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,39 +123,41 @@ class LinkedCellList
LinkedCellList() {}

/*!
\brief Slice constructor
\brief Simple constructor
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param grid_delta Grid sizes in each cardinal direction.
\param grid_min Grid minimum value in each direction.
\param grid_max Grid maximum value in each direction.
*/
template <class SliceType>
LinkedCellList( SliceType positions, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
typename std::enable_if<( is_slice<SliceType>::value ),
int>::type* = 0 )
template <class PositionType>
LinkedCellList(
PositionType positions, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
typename std::enable_if<( is_slice<PositionType>::value ||
Kokkos::is_view<PositionType>::value ),
int>::type* = 0 )
: _begin( 0 )
, _end( positions.size() )
, _end( size( positions ) )
, _grid( grid_min[0], grid_min[1], grid_min[2], grid_max[0],
grid_max[1], grid_max[2], grid_delta[0], grid_delta[1],
grid_delta[2] )
, _cell_stencil( grid_delta[0], 1.0, grid_min, grid_max )
, _sorted( false )
{
std::size_t np = positions.size();
std::size_t np = size( positions );
allocate( totalBins(), np );
build( positions, 0, np );
}

/*!
\brief Slice constructor
\brief Partial range constructor
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param begin The beginning index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
Expand All @@ -166,12 +168,14 @@ class LinkedCellList
\param grid_min Grid minimum value in each direction.
\param grid_max Grid maximum value in each direction.
*/
template <class SliceType>
LinkedCellList( SliceType positions, const std::size_t begin,
const std::size_t end, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
typename std::enable_if<( is_slice<SliceType>::value ),
int>::type* = 0 )
template <class PositionType>
LinkedCellList(
PositionType positions, const std::size_t begin, const std::size_t end,
const Scalar grid_delta[3], const Scalar grid_min[3],
const Scalar grid_max[3],
typename std::enable_if<( is_slice<PositionType>::value ||
Kokkos::is_view<PositionType>::value ),
int>::type* = 0 )
: _begin( begin )
, _end( end )
, _grid( grid_min[0], grid_min[1], grid_min[2], grid_max[0],
Expand All @@ -185,44 +189,45 @@ class LinkedCellList
}

/*!
\brief Slice constructor
\brief Explicit stencil constructor
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param grid_delta Grid sizes in each cardinal direction.
\param grid_min Grid minimum value in each direction.
\param grid_max Grid maximum value in each direction.
\param neighborhood_radius Radius for neighbors.
\param cell_size_ratio Ratio of the cell size to the neighborhood size.
*/
template <class SliceType>
template <class PositionType>
LinkedCellList(
SliceType positions, const Scalar grid_delta[3],
PositionType positions, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
const Scalar neighborhood_radius, const Scalar cell_size_ratio = 1,
typename std::enable_if<( is_slice<SliceType>::value ), int>::type* =
0 )
typename std::enable_if<( is_slice<PositionType>::value ||
Kokkos::is_view<PositionType>::value ),
int>::type* = 0 )
: _begin( 0 )
, _end( positions.size() )
, _end( size( positions ) )
, _grid( grid_min[0], grid_min[1], grid_min[2], grid_max[0],
grid_max[1], grid_max[2], grid_delta[0], grid_delta[1],
grid_delta[2] )
, _cell_stencil( neighborhood_radius, cell_size_ratio, grid_min,
grid_max )
, _sorted( false )
{
std::size_t np = positions.size();
std::size_t np = size( positions );
allocate( totalBins(), np );
build( positions, 0, np );
}

/*!
\brief Slice range constructor
\brief Explicit stencil and partial range constructor
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param begin The beginning index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
Expand All @@ -235,14 +240,15 @@ class LinkedCellList
\param neighborhood_radius Radius for neighbors.
\param cell_size_ratio Ratio of the cell size to the neighborhood size.
*/
template <class SliceType>
template <class PositionType>
LinkedCellList(
SliceType positions, const std::size_t begin, const std::size_t end,
PositionType positions, const std::size_t begin, const std::size_t end,
const Scalar grid_delta[3], const Scalar grid_min[3],
const Scalar grid_max[3], const Scalar neighborhood_radius,
const Scalar cell_size_ratio = 1,
typename std::enable_if<( is_slice<SliceType>::value ), int>::type* =
0 )
typename std::enable_if<( is_slice<PositionType>::value ||
Kokkos::is_view<PositionType>::value ),
int>::type* = 0 )
: _begin( begin )
, _end( end )
, _grid( grid_min[0], grid_min[1], grid_min[2], grid_max[0],
Expand Down Expand Up @@ -382,26 +388,26 @@ class LinkedCellList
\brief Build the linked cell list with a subset of particles.
\tparam ExecutionSpace Kokkos execution space.
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param begin The beginning index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
\param end The end index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
*/
template <class ExecutionSpace, class SliceType>
void build( ExecutionSpace, SliceType positions, const std::size_t begin,
template <class ExecutionSpace, class PositionType>
void build( ExecutionSpace, PositionType positions, const std::size_t begin,
const std::size_t end )
{
Kokkos::Profiling::ScopedRegion region(
"Cabana::LinkedCellList::build" );

static_assert( is_accessible_from<memory_space, ExecutionSpace>{}, "" );
assert( end >= begin );
assert( end <= positions.size() );
assert( end <= size( positions ) );

// Resize the binning data. Note that the permutation vector spans
// only the length of begin-end;
Expand Down Expand Up @@ -481,18 +487,18 @@ class LinkedCellList
/*!
\brief Build the linked cell list with a subset of particles.
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
\param begin The beginning index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
\param end The end index of particles to bin or find neighbors
for. Particles outside this range will NOT be considered as candidate
neighbors.
*/
template <class SliceType>
void build( SliceType positions, const std::size_t begin,
template <class PositionType>
void build( PositionType positions, const std::size_t begin,
const std::size_t end )
{
// Use the default execution space.
Expand All @@ -502,14 +508,14 @@ class LinkedCellList
/*!
\brief Build the linked cell list with all particles.
\tparam SliceType Slice type for positions.
\tparam PositionType Type for positions.
\param positions Slice of positions.
\param positions Particle positions.
*/
template <class SliceType>
void build( SliceType positions )
template <class PositionType>
void build( PositionType positions )
{
build( positions, 0, positions.size() );
build( positions, 0, size( positions ) );
}

/*!
Expand Down Expand Up @@ -640,8 +646,8 @@ class LinkedCellList
\brief Creation function for linked cell list.
\return LinkedCellList.
*/
template <class MemorySpace, class SliceType, class Scalar>
auto createLinkedCellList( SliceType positions, const Scalar grid_delta[3],
template <class MemorySpace, class PositionType, class Scalar>
auto createLinkedCellList( PositionType positions, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3] )
{
return LinkedCellList<MemorySpace, Scalar>( positions, grid_delta, grid_min,
Expand All @@ -652,8 +658,8 @@ auto createLinkedCellList( SliceType positions, const Scalar grid_delta[3],
\brief Creation function for linked cell list with partial range.
\return LinkedCellList.
*/
template <class MemorySpace, class SliceType, class Scalar>
auto createLinkedCellList( SliceType positions, const std::size_t begin,
template <class MemorySpace, class PositionType, class Scalar>
auto createLinkedCellList( PositionType positions, const std::size_t begin,
const std::size_t end, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3] )
{
Expand All @@ -666,8 +672,8 @@ auto createLinkedCellList( SliceType positions, const std::size_t begin,
cell ratio.
\return LinkedCellList.
*/
template <class MemorySpace, class SliceType, class Scalar>
auto createLinkedCellList( SliceType positions, const Scalar grid_delta[3],
template <class MemorySpace, class PositionType, class Scalar>
auto createLinkedCellList( PositionType positions, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
const Scalar neighborhood_radius,
const Scalar cell_size_ratio = 1.0 )
Expand All @@ -682,8 +688,8 @@ auto createLinkedCellList( SliceType positions, const Scalar grid_delta[3],
cutoff radius and/or cell ratio.
\return LinkedCellList.
*/
template <class MemorySpace, class SliceType, class Scalar>
auto createLinkedCellList( SliceType positions, const std::size_t begin,
template <class MemorySpace, class PositionType, class Scalar>
auto createLinkedCellList( PositionType positions, const std::size_t begin,
const std::size_t end, const Scalar grid_delta[3],
const Scalar grid_min[3], const Scalar grid_max[3],
const Scalar neighborhood_radius,
Expand Down Expand Up @@ -717,51 +723,26 @@ struct is_linked_cell_list

//---------------------------------------------------------------------------//
/*!
\brief Given a linked cell list permute an AoSoA.
\tparam LinkedCellListType The linked cell list type.
\tparam AoSoA_t The AoSoA type.
\param linked_cell_list The linked cell list to permute the AoSoA with.
\param aosoa The AoSoA to permute.
*/
template <class LinkedCellListType, class AoSoA_t>
void permute(
LinkedCellListType& linked_cell_list, AoSoA_t& aosoa,
typename std::enable_if<( is_linked_cell_list<LinkedCellListType>::value &&
is_aosoa<AoSoA_t>::value ),
int>::type* = 0 )
{
permute( linked_cell_list.binningData(), aosoa );

// Update internal state.
linked_cell_list.update( true );

linked_cell_list.storeParticleBins();
}

//---------------------------------------------------------------------------//
/*!
\brief Given a linked cell list permute a slice.
\brief Given a linked cell list permute positions.
\tparam LinkedCellListType The linked cell list type.
\tparam SliceType The slice type.
\tparam PositionType Positions type (AoSoA or Slice or Kokkos::View).
\param linked_cell_list The linked cell list to permute the slice with.
\param linked_cell_list The linked cell list to permute the positions with.
\param slice The slice to permute.
\param positions Positions to permute.
*/
template <class LinkedCellListType, class SliceType>
template <class LinkedCellListType, class PositionType>
void permute(
LinkedCellListType& linked_cell_list, SliceType& slice,
LinkedCellListType& linked_cell_list, PositionType& positions,
typename std::enable_if<( is_linked_cell_list<LinkedCellListType>::value &&
is_slice<SliceType>::value ),
( is_aosoa<PositionType>::value ||
is_slice<PositionType>::value ||
Kokkos::is_view<PositionType>::value ) ),
int>::type* = 0 )
{
permute( linked_cell_list.binningData(), slice );
permute( linked_cell_list.binningData(), positions );

// Update internal state.
linked_cell_list.update( true );
Expand Down
Loading

0 comments on commit 142068e

Please sign in to comment.