Skip to content

Commit

Permalink
Merge pull request #734 from streeve/hypre_inheritance
Browse files Browse the repository at this point in the history
Inherit Hypre matrices/vectors to avoid memory leak
  • Loading branch information
streeve authored Jan 22, 2024
2 parents 9ef42d2 + 736ec45 commit 9cc3824
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 102 deletions.
85 changes: 46 additions & 39 deletions grid/src/Cabana_Grid_HypreSemiStructuredSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ class HypreSemiStructuredSolver
auto error = HYPRE_SStructMatrixAssemble( _A );
checkHypreError( error );

this->setupImpl( _A, _b, _x );
this->setupImpl();
}

/*!
Expand Down Expand Up @@ -556,7 +556,7 @@ class HypreSemiStructuredSolver
checkHypreError( error );

// Solve the problem
this->solveImpl( _A, _b, _x );
this->solveImpl();

// Extract the solution from the LHS
for ( int var = 0; var < n_vars; ++var )
Expand Down Expand Up @@ -612,12 +612,10 @@ class HypreSemiStructuredSolver
virtual void setPrintLevelImpl( const int print_level ) = 0;

//! Setup implementation.
virtual void setupImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) = 0;
virtual void setupImpl() = 0;

//! Solver implementation.
virtual void solveImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) = 0;
virtual void solveImpl() = 0;

//! Get the number of iterations taken on the last solve.
virtual int getNumIterImpl() = 0;
Expand Down Expand Up @@ -645,6 +643,13 @@ class HypreSemiStructuredSolver
}
}

//! Matrix for the problem Ax = b.
HYPRE_SStructMatrix _A;
//! Forcing term for the problem Ax = b.
HYPRE_SStructVector _b;
//! Solution to the problem Ax = b.
HYPRE_SStructVector _x;

private:
MPI_Comm _comm;
bool _is_preconditioner;
Expand All @@ -655,9 +660,6 @@ class HypreSemiStructuredSolver
HYPRE_SStructGraph _graph;
std::vector<unsigned> _stencil_size;
std::vector<std::vector<unsigned>> _stencil_index;
HYPRE_SStructMatrix _A;
HYPRE_SStructVector _b;
HYPRE_SStructVector _x;
std::shared_ptr<HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>>
_preconditioner;
};
Expand All @@ -670,12 +672,13 @@ class HypreSemiStructPCG
{
public:
//! Base HYPRE semi-structured solver type.
using Base = HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
using base_type =
HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
//! Constructor
template <class ArrayLayout_t>
HypreSemiStructPCG( const ArrayLayout_t& layout, int n_vars,
const bool is_preconditioner = false )
: Base( layout, n_vars, is_preconditioner )
: base_type( layout, n_vars, is_preconditioner )
{
if ( is_preconditioner )
throw std::logic_error(
Expand Down Expand Up @@ -743,17 +746,15 @@ class HypreSemiStructPCG
this->checkHypreError( error );
}

void setupImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void setupImpl() override
{
auto error = HYPRE_SStructPCGSetup( _solver, A, b, x );
auto error = HYPRE_SStructPCGSetup( _solver, _A, _b, _x );
this->checkHypreError( error );
}

void solveImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void solveImpl() override
{
auto error = HYPRE_SStructPCGSolve( _solver, A, b, x );
auto error = HYPRE_SStructPCGSolve( _solver, _A, _b, _x );
this->checkHypreError( error );
}

Expand Down Expand Up @@ -787,6 +788,9 @@ class HypreSemiStructPCG

private:
HYPRE_SStructSolver _solver;
using base_type::_A;
using base_type::_b;
using base_type::_x;
};

//---------------------------------------------------------------------------//
Expand All @@ -797,12 +801,13 @@ class HypreSemiStructGMRES
{
public:
//! Base HYPRE semi-structured solver type.
using Base = HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
using base_type =
HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
//! Constructor
template <class ArrayLayout_t>
HypreSemiStructGMRES( const ArrayLayout_t& layout, int n_vars,
const bool is_preconditioner = false )
: Base( layout, n_vars, is_preconditioner )
: base_type( layout, n_vars, is_preconditioner )
{
if ( is_preconditioner )
throw std::logic_error(
Expand Down Expand Up @@ -867,17 +872,15 @@ class HypreSemiStructGMRES
this->checkHypreError( error );
}

void setupImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void setupImpl() override
{
auto error = HYPRE_SStructGMRESSetup( _solver, A, b, x );
auto error = HYPRE_SStructGMRESSetup( _solver, _A, _b, _x );
this->checkHypreError( error );
}

void solveImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void solveImpl() override
{
auto error = HYPRE_SStructGMRESSolve( _solver, A, b, x );
auto error = HYPRE_SStructGMRESSolve( _solver, _A, _b, _x );
this->checkHypreError( error );
}

Expand Down Expand Up @@ -911,6 +914,9 @@ class HypreSemiStructGMRES

private:
HYPRE_SStructSolver _solver;
using base_type::_A;
using base_type::_b;
using base_type::_x;
};

//---------------------------------------------------------------------------//
Expand All @@ -921,13 +927,14 @@ class HypreSemiStructBiCGSTAB
{
public:
//! Base HYPRE semi-structured solver type.
using Base = HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
using base_type =
HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
//! Constructor
template <class ArrayLayout_t>
HypreSemiStructBiCGSTAB( const ArrayLayout_t& layout,
const bool is_preconditioner = false,
int n_vars = 3 )
: Base( layout, n_vars, is_preconditioner )
: base_type( layout, n_vars, is_preconditioner )
{
if ( is_preconditioner )
throw std::logic_error(
Expand Down Expand Up @@ -985,17 +992,15 @@ class HypreSemiStructBiCGSTAB
this->checkHypreError( error );
}

void setupImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void setupImpl() override
{
auto error = HYPRE_SStructBiCGSTABSetup( _solver, A, b, x );
auto error = HYPRE_SStructBiCGSTABSetup( _solver, _A, _b, _x );
this->checkHypreError( error );
}

void solveImpl( HYPRE_SStructMatrix A, HYPRE_SStructVector b,
HYPRE_SStructVector x ) override
void solveImpl() override
{
auto error = HYPRE_SStructBiCGSTABSolve( _solver, A, b, x );
auto error = HYPRE_SStructBiCGSTABSolve( _solver, _A, _b, _x );
this->checkHypreError( error );
}

Expand Down Expand Up @@ -1030,6 +1035,9 @@ class HypreSemiStructBiCGSTAB

private:
HYPRE_SStructSolver _solver;
using base_type::_A;
using base_type::_b;
using base_type::_x;
};

//---------------------------------------------------------------------------//
Expand All @@ -1040,13 +1048,14 @@ class HypreSemiStructDiagonal
{
public:
//! Base HYPRE semi-structured solver type.
using Base = HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
using base_type =
HypreSemiStructuredSolver<Scalar, EntityType, MemorySpace>;
//! Constructor
template <class ArrayLayout_t>
HypreSemiStructDiagonal( const ArrayLayout_t& layout,
const bool is_preconditioner = false,
int n_vars = 3 )
: Base( layout, n_vars, is_preconditioner )
: base_type( layout, n_vars, is_preconditioner )
{
if ( !is_preconditioner )
throw std::logic_error(
Expand Down Expand Up @@ -1082,15 +1091,13 @@ class HypreSemiStructDiagonal
"Diagonal preconditioner cannot be used as a solver" );
}

void setupImpl( HYPRE_SStructMatrix, HYPRE_SStructVector,
HYPRE_SStructVector ) override
void setupImpl() override
{
throw std::logic_error(
"Diagonal preconditioner cannot be used as a solver" );
}

void solveImpl( HYPRE_SStructMatrix, HYPRE_SStructVector,
HYPRE_SStructVector ) override
void solveImpl() override
{
throw std::logic_error(
"Diagonal preconditioner cannot be used as a solver" );
Expand Down
Loading

0 comments on commit 9cc3824

Please sign in to comment.