diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H index 18a123d0f6299604e9102561513976076bd29e69..916cd5750c33dcd85f9c992db0498efbcd14378a 100644 --- a/src/AmgXSolver.H +++ b/src/AmgXSolver.H @@ -206,6 +206,31 @@ class AmgXSolver AmgXCSRMatrix& matrix ); + /** \brief Solve the linear system. + * + * \p p vector will be used as an initial guess and will be updated to the + * solution by the end of solving. + * + * For cases that use more MPI processes than the number of GPUs, this + * function will do data gathering before solving and data scattering + * after the solving. + * + * \param nLocalRows [in] The number of rows owned by this rank. + * \param p [in, out] The unknown vector. + * \param b [in] The RHS vector. + * \param matrix [in,out] The AmgX CSR matrix, A. + * + * \return PetscErrorCode. + */ + PetscErrorCode solve + ( + int nLocalRows, + Vec& p, + Vec& b, + AmgXCSRMatrix& matrix + ); + + /** \brief Get the number of iterations of the last solving. * * \param iter [out] Number of iterations. diff --git a/src/AmgXSolver.cu b/src/AmgXSolver.cu index 5c86a326f1ef8f1ad11ed55eec00782f90a84206..2f96a630782ea4683502d263f2cb7fa149dcba77 100644 --- a/src/AmgXSolver.cu +++ b/src/AmgXSolver.cu @@ -327,6 +327,29 @@ PetscErrorCode AmgXSolver::updateOperator PetscFunctionReturn(0); } +/* \implements AmgXSolver::solve */ +PetscErrorCode AmgXSolver::solve( + int nLocalRows, Vec& p, Vec& b, AmgXCSRMatrix& matrix) +{ + PetscFunctionBeginUser; + + PetscScalar* pscalar; + PetscScalar* bscalar; + + PetscInt ierr; + + // get pointers to the raw data of local vectors + ierr = VecGetArray(p, &pscalar); CHK; + ierr = VecGetArray(b, &bscalar); CHK; + + solve(nLocalRows, pscalar, bscalar, matrix); + + ierr = VecRestoreArray(p, &pscalar); CHK; + ierr = VecRestoreArray(b, &bscalar); CHK; + + PetscFunctionReturn(0); +} + /* \implements AmgXSolver::solve */ PetscErrorCode AmgXSolver::solve(