Skip to content

Commit

Permalink
Fix loadStateField evaluator to address #1092 (#1096)
Browse files Browse the repository at this point in the history
Apparently linear access of Kokkos dynamic rank views is no longer working
  • Loading branch information
mperego authored Dec 4, 2024
1 parent 26d3dda commit 238f75f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 26 deletions.
8 changes: 2 additions & 6 deletions src/evaluators/state/PHAL_LoadStateField.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ class LoadStateFieldBase : public PHX::EvaluatorWithBaseImpl<Traits>,

using ExecutionSpace = typename PHX::Device::execution_space;

PHX::MDField<ScalarType> data;
PHX::MDField<ScalarType> field;
std::string fieldName;
std::string stateName;

MDFieldMemoizer<Traits> memoizer;

MDFieldVectorRight<ScalarType> dataVec;
};

template<typename EvalT, typename Traits>
Expand All @@ -65,13 +63,11 @@ class LoadStateField : public PHX::EvaluatorWithBaseImpl<Traits>,

using ExecutionSpace = typename PHX::Device::execution_space;

PHX::MDField<ParamScalarT> data;
PHX::MDField<ParamScalarT> field;
std::string fieldName;
std::string stateName;

MDFieldMemoizer<Traits> memoizer;

MDFieldVectorRight<ParamScalarT> dataVec;
};

// Shortcut names
Expand Down
35 changes: 15 additions & 20 deletions src/evaluators/state/PHAL_LoadStateField_Def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ LoadStateFieldBase(const Teuchos::ParameterList& p)
fieldName = p.get<std::string>("Field Name");
stateName = p.get<std::string>("State Name");

PHX::MDField<ScalarType> f(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
data = f;
field = PHX::MDField<ScalarType>(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );

this->addEvaluatedField(data);
this->addEvaluatedField(field);
this->setName("LoadStateField("+stateName+")"+PHX::print<EvalT>());
}

Expand All @@ -34,7 +33,7 @@ template<typename EvalT, typename Traits, typename ScalarType>
void LoadStateFieldBase<EvalT, Traits, ScalarType>::postRegistrationSetup(typename Traits::SetupData d,
PHX::FieldManager<Traits>& fm)
{
this->utils.setFieldData(data,fm);
this->utils.setFieldData(field,fm);

d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields());
if (d.memoizer_active()) memoizer.enable_memoizer();
Expand All @@ -51,15 +50,13 @@ void LoadStateFieldBase<EvalT, Traits, ScalarType>::evaluateFields(typename Trai
// whomever changed the data.
const auto& stateToLoad = (*workset.stateArrayPtr)[stateName];
auto stateData = stateToLoad.dev();
const int stateToLoad_size = stateToLoad.size();

MDFieldVectorRight<ScalarType> g(data);
dataVec = g;
ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below");

Kokkos::parallel_for(this->getName(),
Kokkos::RangePolicy<ExecutionSpace>(0,data.size()),
KOKKOS_CLASS_LAMBDA(const int i) {
dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0;
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<3>>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}),
KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) {
field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3
});
}

Expand All @@ -70,10 +67,10 @@ LoadStateField(const Teuchos::ParameterList& p)
fieldName = p.get<std::string>("Field Name");
stateName = p.get<std::string>("State Name");

PHX::MDField<ParamScalarT> f(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
data = f;

this->addEvaluatedField(data);
field = PHX::MDField<ParamScalarT>(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );

this->addEvaluatedField(field);
this->setName("Load State Field"+PHX::print<EvalT>());
}

Expand All @@ -82,7 +79,7 @@ template<typename EvalT, typename Traits>
void LoadStateField<EvalT, Traits>::postRegistrationSetup(typename Traits::SetupData d,
PHX::FieldManager<Traits>& fm)
{
this->utils.setFieldData(data,fm);
this->utils.setFieldData(field,fm);

d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields());
if (d.memoizer_active()) memoizer.enable_memoizer();
Expand All @@ -99,15 +96,13 @@ void LoadStateField<EvalT, Traits>::evaluateFields(typename Traits::EvalData wor
// whomever changed the data.
const auto& stateToLoad = (*workset.stateArrayPtr)[stateName];
auto stateData = stateToLoad.dev();
const int stateToLoad_size = stateToLoad.size();

MDFieldVectorRight<ParamScalarT> g(data);
dataVec = g;
ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below");

Kokkos::parallel_for(this->getName(),
Kokkos::RangePolicy<ExecutionSpace>(0,data.size()),
KOKKOS_CLASS_LAMBDA(const int i) {
dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0;
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<3>>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}),
KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) {
field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3
});
}

Expand Down

0 comments on commit 238f75f

Please sign in to comment.