From a0d8d215deb7535bc99c613cc4510eb20c8795f7 Mon Sep 17 00:00:00 2001
From: Simone Bna <s.bn@cineca.it>
Date: Thu, 27 May 2021 14:45:31 +0200
Subject: [PATCH] ENH: added new solve API

---
 src/AmgXSolver.H  | 25 +++++++++++++++++++++++++
 src/AmgXSolver.cu | 23 +++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H
index 18a123d..916cd57 100644
--- a/src/AmgXSolver.H
+++ b/src/AmgXSolver.H
@@ -206,6 +206,31 @@ class AmgXSolver
             AmgXCSRMatrix& matrix
         );
 
+	/** \brief Solve the linear system.
+         *
+         * \p p vector will be used as an initial guess and will be updated to the
+         * solution by the end of solving.
+         *
+         * For cases that use more MPI processes than the number of GPUs, this
+         * function will do data gathering before solving and data scattering
+         * after the solving.
+         *
+         * \param nLocalRows [in] The number of rows owned by this rank.
+         * \param p [in, out] The unknown vector.
+         * \param b [in] The RHS vector.
+         * \param matrix [in,out] The AmgX CSR matrix, A.
+         *
+         * \return PetscErrorCode.
+         */
+        PetscErrorCode solve
+        (
+            int nLocalRows,
+            Vec& p,
+            Vec& b,
+            AmgXCSRMatrix& matrix
+        );
+
+
         /** \brief Get the number of iterations of the last solving.
          *
          * \param iter [out] Number of iterations.
diff --git a/src/AmgXSolver.cu b/src/AmgXSolver.cu
index 5c86a32..2f96a63 100644
--- a/src/AmgXSolver.cu
+++ b/src/AmgXSolver.cu
@@ -327,6 +327,29 @@ PetscErrorCode AmgXSolver::updateOperator
     PetscFunctionReturn(0);
 }
 
+/* \implements AmgXSolver::solve */
+PetscErrorCode AmgXSolver::solve(
+    int nLocalRows, Vec& p, Vec& b, AmgXCSRMatrix& matrix)
+{
+    PetscFunctionBeginUser;
+
+    PetscScalar* pscalar;
+    PetscScalar* bscalar;
+
+    PetscInt ierr;
+
+    // get pointers to the raw data of local vectors
+    ierr = VecGetArray(p, &pscalar); CHK;
+    ierr = VecGetArray(b, &bscalar); CHK;
+
+    solve(nLocalRows, pscalar, bscalar, matrix);
+
+    ierr = VecRestoreArray(p, &pscalar); CHK;
+    ierr = VecRestoreArray(b, &bscalar); CHK;
+
+    PetscFunctionReturn(0);
+}
+
 
 /* \implements AmgXSolver::solve */
 PetscErrorCode AmgXSolver::solve(
-- 
GitLab