diff --git a/Make/files b/Make/files
index 1c1778abfde22c3a42cf233dd4efce7cd049012a..56620373550d52fc0ae688a8540852b2fe0f38cf 100644
--- a/Make/files
+++ b/Make/files
@@ -1,7 +1,6 @@
-src/FOAM2CSR.cu
-src/consolidate.cu
+src/AmgXCSRMatrix.cu
+src/AmgXConsolidation.cu
+src/AmgXMPIComms.cu
 src/AmgXSolver.cu
-src/init.cu
-src/setA.cu
 
 LIB = $(FOAM_MODULE_LIBBIN)/libfoam2csr
diff --git a/src/FOAM2CSR.H b/src/AmgXCSRMatrix.H
similarity index 99%
rename from src/FOAM2CSR.H
rename to src/AmgXCSRMatrix.H
index 58bf62c604cf1438573c0936fe1fc208b072149b..82268d8cff1ff7e75bc5f3d6c1b3f80c3412a8e6 100644
--- a/src/FOAM2CSR.H
+++ b/src/AmgXCSRMatrix.H
@@ -22,7 +22,7 @@
 
 #pragma once
 
-struct FOAM2CSR
+struct AmgXCSRMatrix
 {
     // CSR device data for AmgX matrix
     int *colIndices;
diff --git a/src/FOAM2CSR.cu b/src/AmgXCSRMatrix.cu
similarity index 98%
rename from src/FOAM2CSR.cu
rename to src/AmgXCSRMatrix.cu
index dd7b5a65cd67ed355b04e55843889acb1ceac02b..6db5131d9e48f047f9e8fa187a0f9e508ba42b9f 100644
--- a/src/FOAM2CSR.cu
+++ b/src/AmgXCSRMatrix.cu
@@ -20,7 +20,8 @@
  * DEALINGS IN THE SOFTWARE.
  */
 
-#include <FOAM2CSR.H>
+#include <AmgXCSRMatrix.H>
+
 #include <cuda.h>
 #include <cub/cub.cuh>
 #include <thrust/scan.h>
@@ -127,7 +128,7 @@ __global__ void createRowOffsets(
 }
 
 // Updates the values based on the previously determined permutation
-void FOAM2CSR::updateValues(
+void AmgXCSRMatrix::updateValues(
     const int nrows,
     const int nInternalFaces,
     const int extNnz,
@@ -158,7 +159,7 @@ void FOAM2CSR::updateValues(
 }
 
 // Perform the conversion between an LDU matrix and a CSR matrix, possibly distributed
-void FOAM2CSR::convertLDU2CSR(
+void AmgXCSRMatrix::convertLDU2CSR(
     int nrows,
     int nInternalFaces,
     int diagIndexGlobal,
@@ -265,12 +266,12 @@ void FOAM2CSR::convertLDU2CSR(
 
 // XXX Should implement an early abandonment of the
 // unnecessary data for capacity optimisation
-void FOAM2CSR::discardStructure()
+void AmgXCSRMatrix::discardStructure()
 {
 }
 
 // Deallocate remaining storage
-void FOAM2CSR::finalise()
+void AmgXCSRMatrix::finalise()
 {
     if(rowOffsets != nullptr)
         CHECK(cudaFree(rowOffsets));
@@ -283,3 +284,4 @@ void FOAM2CSR::finalise()
     if(ldu2csrPerm != nullptr)
         CHECK(cudaFree(ldu2csrPerm));
 }
+
diff --git a/src/consolidate.cu b/src/AmgXConsolidation.cu
similarity index 99%
rename from src/consolidate.cu
rename to src/AmgXConsolidation.cu
index 853a965d02329d5e9a3009bdfa686356002e23dc..87e7681b65013b753449a367ad576b500bdc755a 100644
--- a/src/consolidate.cu
+++ b/src/AmgXConsolidation.cu
@@ -18,7 +18,7 @@
  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  * DEALINGS IN THE SOFTWARE.
- * \file consolidate.cu
+ * \file AmgXConsolidation.cu
  * \brief Definition of member functions related to matrix consolidation.
  * \author Matt Martineau (mmartineau@nvidia.com)
  * \date 2020-07-31
diff --git a/src/AmgXMPIComms.cu b/src/AmgXMPIComms.cu
new file mode 100644
index 0000000000000000000000000000000000000000..148b2553da58a6f6e86aff4239673a72cf2fe714
--- /dev/null
+++ b/src/AmgXMPIComms.cu
@@ -0,0 +1,156 @@
+/**
+ * \file AmgXMPIComms.cu
+ * \brief Definition of member functions related to MPI communicators and device selection.
+ * \author Pi-Yueh Chuang (pychuang@gwu.edu)
+ * \date 2015-09-01
+ * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
+ *            This project is released under MIT License.
+ */
+
+// AmgXWrapper
+# include "AmgXSolver.H"
+
+
+/* \implements AmgXSolver::initMPIcomms */
+PetscErrorCode AmgXSolver::initMPIcomms(const MPI_Comm &comm)
+{
+    PetscErrorCode      ierr;
+
+    PetscFunctionBeginUser;
+
+    // duplicate the global communicator
+    ierr = MPI_Comm_dup(comm, &globalCpuWorld); CHK;
+    ierr = MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld"); CHK;
+
+    // get size and rank for global communicator
+    ierr = MPI_Comm_size(globalCpuWorld, &globalSize); CHK;
+    ierr = MPI_Comm_rank(globalCpuWorld, &myGlobalRank); CHK;
+
+
+    // Get the communicator for processors on the same node (local world)
+    ierr = MPI_Comm_split_type(globalCpuWorld,
+            MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld); CHK;
+    ierr = MPI_Comm_set_name(localCpuWorld, "localCpuWorld"); CHK;
+
+    // get size and rank for local communicator
+    ierr = MPI_Comm_size(localCpuWorld, &localSize); CHK;
+    ierr = MPI_Comm_rank(localCpuWorld, &myLocalRank); CHK;
+
+
+    // set up the variable nDevs
+    ierr = setDeviceCount(); CHK;
+
+
+    // set up corresponding ID of the device used by each local process
+    ierr = setDeviceIDs(); CHK;
+    ierr = MPI_Barrier(globalCpuWorld); CHK;
+
+
+    // split the global world into a world involved in AmgX and a null world
+    ierr = MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld); CHK;
+
+    // get size and rank for the communicator corresponding to gpuWorld
+    if (gpuWorld != MPI_COMM_NULL)
+    {
+        ierr = MPI_Comm_set_name(gpuWorld, "gpuWorld"); CHK;
+        ierr = MPI_Comm_size(gpuWorld, &gpuWorldSize); CHK;
+        ierr = MPI_Comm_rank(gpuWorld, &myGpuWorldRank); CHK;
+    }
+    else // for ranks that cannot communicate with GPU devices
+    {
+        gpuWorldSize = MPI_UNDEFINED;
+        myGpuWorldRank = MPI_UNDEFINED;
+    }
+
+    // split local world into worlds corresponding to each CUDA device
+    ierr = MPI_Comm_split(localCpuWorld, devID, 0, &devWorld); CHK;
+    ierr = MPI_Comm_set_name(devWorld, "devWorld"); CHK;
+
+    // get size and rank for the communicator corresponding to devWorld
+    ierr = MPI_Comm_size(devWorld, &devWorldSize); CHK;
+    ierr = MPI_Comm_rank(devWorld, &myDevWorldRank); CHK;
+
+    ierr = MPI_Barrier(globalCpuWorld); CHK;
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::setDeviceCount */
+PetscErrorCode AmgXSolver::setDeviceCount()
+{
+    PetscErrorCode      ierr;
+
+    PetscFunctionBeginUser;
+
+    // get the number of devices that AmgX solvers can use
+    switch (mode)
+    {
+        case AMGX_mode_dDDI: // for GPU cases, nDevs is the # of local GPUs
+        case AMGX_mode_dDFI: // for GPU cases, nDevs is the # of local GPUs
+        case AMGX_mode_dFFI: // for GPU cases, nDevs is the # of local GPUs
+            // get the number of total cuda devices
+            CHECK(cudaGetDeviceCount(&nDevs));
+            ierr = PetscPrintf(localCpuWorld, "Number of GPU devices :: %d \n", nDevs); CHK;
+
+            // Check whether there is at least one CUDA device on this node
+            if (nDevs == 0) SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
+                    "There is no CUDA device on the node %s !\n", nodeName.c_str());
+            break;
+        case AMGX_mode_hDDI: // for CPU cases, nDevs is the # of local processes
+        case AMGX_mode_hDFI: // for CPU cases, nDevs is the # of local processes
+        case AMGX_mode_hFFI: // for CPU cases, nDevs is the # of local processes
+        default:
+            nDevs = localSize;
+            break;
+    }
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::setDeviceIDs */
+PetscErrorCode AmgXSolver::setDeviceIDs()
+{
+    PetscFunctionBeginUser;
+
+    PetscErrorCode      ierr;
+
+    // set the ID of device that each local process will use
+    if (nDevs == localSize) // # of devices and local processes are the same
+    {
+        devID = myLocalRank;
+        gpuProc = 0;
+    }
+    else if (nDevs > localSize) // there are more devices than processes
+    {
+        ierr = PetscPrintf(localCpuWorld, "CUDA devices on the node %s "
+                "are more than the MPI processes launched. Only %d CUDA "
+                "devices will be used.\n", nodeName.c_str(), localSize); CHK;
+
+        devID = myLocalRank;
+        gpuProc = 0;
+    }
+    else // there are more processes than devices
+    {
+        int     nBasic = localSize / nDevs,
+                nRemain = localSize % nDevs;
+
+        if (myLocalRank < (nBasic+1)*nRemain)
+        {
+            devID = myLocalRank / (nBasic + 1);
+            if (myLocalRank % (nBasic + 1) == 0)  gpuProc = 0;
+        }
+        else
+        {
+            devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
+            if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) gpuProc = 0;
+        }
+    }
+
+    // Set the device for each rank
+    CHECK(cudaSetDevice(devID));
+
+    PetscFunctionReturn(0);
+}
+
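
The oversubscription arithmetic in setDeviceIDs() above is easiest to check with concrete numbers. A minimal standalone sketch (localSize and nDevs are assumed example values, not taken from the wrapper) that reproduces the mapping:

```cpp
// Illustrative only: reproduces the rank -> device mapping from
// AmgXSolver::setDeviceIDs() for the oversubscribed case (more local
// MPI ranks than GPUs). The first nRemain devices serve nBasic+1 ranks
// each; the rest serve nBasic ranks each. The first rank of each group
// becomes the devWorld leader (gpuProc == 0).
#include <cstdio>

int main()
{
    const int localSize = 7;  // assumed: MPI ranks on this node
    const int nDevs     = 3;  // assumed: CUDA devices on this node

    const int nBasic  = localSize / nDevs;   // 2
    const int nRemain = localSize % nDevs;   // 1

    for (int rank = 0; rank < localSize; ++rank)
    {
        int devID;
        bool leader;
        if (rank < (nBasic + 1) * nRemain)
        {
            devID  = rank / (nBasic + 1);
            leader = (rank % (nBasic + 1) == 0);
        }
        else
        {
            const int r = rank - (nBasic + 1) * nRemain;
            devID  = r / nBasic + nRemain;
            leader = (r % nBasic == 0);
        }
        std::printf("local rank %d -> device %d%s\n",
                rank, devID, leader ? " (devWorld leader)" : "");
    }
    return 0;  // ranks 0-2 -> dev 0, ranks 3-4 -> dev 1, ranks 5-6 -> dev 2
}
```
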
diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H
index 707057302968f97ede8b9fd2ddfb7f26e0644b6f..bd6b69cd741b23cb842fa26e1d1970ea13ff4eb6 100644
--- a/src/AmgXSolver.H
+++ b/src/AmgXSolver.H
@@ -23,7 +23,6 @@
 # include <amgx_c.h>
 
 // PETSc
-# include <petscmat.h>
 # include <petscvec.h>
 
 
@@ -92,7 +91,6 @@ class AmgXSolver
         /** \brief Default constructor. */
         AmgXSolver() = default;
 
-
         /** \brief Construct a AmgXSolver instance.
          *
          * \param comm [in] MPI communicator.
@@ -102,11 +100,9 @@ class AmgXSolver
         AmgXSolver(const MPI_Comm &comm,
                 const std::string &modeStr, const std::string &cfgFile);
 
-
         /** \brief Destructor. */
         ~AmgXSolver();
 
-
         /** \brief Initialize a AmgXSolver instance.
          *
          * \param comm [in] MPI communicator.
@@ -118,7 +114,6 @@ class AmgXSolver
         PetscErrorCode initialize(const MPI_Comm &comm,
                 const std::string &modeStr, const std::string &cfgFile);
 
-
         /** \brief Finalize this instance.
          *
          * This function destroys AmgX data. When there are more than one
@@ -129,23 +124,6 @@ class AmgXSolver
          */
         PetscErrorCode finalize();
 
-
-        /** \brief Set up the matrix used by AmgX.
-         *
-         * This function will automatically convert PETSc matrix to AmgX matrix.
-         * If the number of MPI processes is higher than the number of available
-         * GPU devices, we also redistribute the matrix in this function and
-         * upload redistributed one to GPUs.
-         *
-         * Note: currently we can only handle AIJ/MPIAIJ format.
-         *
-         * \param A [in] A PETSc Mat.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode setA(const Mat &A);
-
-
         /** \brief Set up the matrix used by AmgX.
          *
          * This function sets up the AmgX matrix from the provided CSR data
@@ -171,7 +149,6 @@ class AmgXSolver
             const PetscScalar* values,
             const PetscInt* partData);
 
-
         /** \brief Re-sets up an existing AmgX matrix.
          *
          * Replaces the matrix coefficients with the provided values and performs
@@ -188,7 +165,6 @@ class AmgXSolver
             const PetscInt nLocalNz,
             const PetscScalar* values);
 
-
         /** \brief Solve the linear system.
          *
          * \p p vector will be used as an initial guess and will be updated to the
@@ -206,7 +182,6 @@ class AmgXSolver
          */
         PetscErrorCode solve(Vec &p, Vec &b, const int nRows);
 
-
         /** \brief Get the number of iterations of the last solving.
          *
          * \param iter [out] Number of iterations.
@@ -215,7 +190,6 @@ class AmgXSolver
          */
         PetscErrorCode getIters(int &iter);
 
-
         /** \brief Get the residual at a specific iteration during the last solving.
          *
          * \param iter [in] Target iteration.
@@ -243,8 +217,6 @@ class AmgXSolver
         std::string             nodeName;
 
 
-
-
         /** \brief Number of local GPU devices used by AmgX.*/
         PetscMPIInt             nDevs;
 
@@ -291,8 +263,6 @@ class AmgXSolver
         PetscMPIInt             myDevWorldRank;
 
 
-
-
         /** \brief A parameter used by AmgX. */
         int                     ring;
 
@@ -322,6 +292,48 @@ class AmgXSolver
          */
         static AMGX_resources_handle   rsrc;
 
+
+        /** \brief Enumeration for the status of matrix consolidation for the solver.*/
+        enum class ConsolidationStatus {
+            Uninitialized,
+            None,
+            Host,
+            Device
+        } consolidationStatus;
+
+        /** \brief The number of non-zeros consolidated from multiple ranks to a device.*/
+        int nConsNz = 0;
+
+        /** \brief The number of rows consolidated from multiple ranks to a device.*/
+        int nConsRows = 0;
+
+        /** \brief The row offsets consolidated onto a single device.*/
+        PetscInt *rowOffsetsCons = nullptr;
+
+        /** \brief The global column indices consolidated onto a single device.*/
+        PetscInt *colIndicesGlobalCons = nullptr;
+
+        /** \brief The values consolidated onto a single device.*/
+        PetscScalar* valuesCons = nullptr;
+
+        /** \brief The solution vector consolidated onto a single device.*/
+        PetscScalar* pCons = nullptr;
+
+        /** \brief The RHS vector consolidated onto a single device.*/
+        PetscScalar* rhsCons = nullptr;
+
+        /** \brief The number of rows per rank associated with a single device.*/
+        std::vector<int> nRowsInDevWorld;
+
+        /** \brief The number of non-zeros per rank associated with a single device.*/
+        std::vector<int> nnzInDevWorld;
+
+        /** \brief The row displacements per rank associated with a single device.*/
+        std::vector<int> rowDispls;
+
+        /** \brief The non-zero displacements per rank associated with a single device.*/
+        std::vector<int> nzDispls;
+
         /** \brief Initialize consolidation, if required. Allocates space to hold
          * consolidated CSR matrix and solution/RHS vectors on the root rank and,
          * if device pointer consolidation, allocates and opens IPC memory handles.
@@ -384,63 +396,6 @@ class AmgXSolver
          */
         PetscErrorCode finalizeConsolidation();
 
-        /** \brief Enumeration for the status of matrix consolidation for the solver.*/
-        enum class ConsolidationStatus {
-            Uninitialized,
-            None,
-            Host,
-            Device
-        } consolidationStatus;
-
-        /** \brief The number of non-zeros consolidated from multiple ranks to a device.*/
-        int nConsNz = 0;
-
-        /** \brief The number of rows consolidated from multiple ranks to a device.*/
-        int nConsRows = 0;
-
-        /** \brief The row offsets consolidated onto a single device.*/
-        PetscInt *rowOffsetsCons = nullptr;
-
-        /** \brief The global column indices consolidated onto a single device.*/
-        PetscInt *colIndicesGlobalCons = nullptr;
-
-        /** \brief The values consolidated onto a single device.*/
-        PetscScalar* valuesCons = nullptr;
-
-        /** \brief The solution vector consolidated onto a single device.*/
-        PetscScalar* pCons = nullptr;
-
-        /** \brief The RHS vector consolidated onto a single device.*/
-        PetscScalar* rhsCons = nullptr;
-
-        /** \brief The number of rows per rank associated with a single device.*/
-        std::vector<int> nRowsInDevWorld;
-
-        /** \brief The number of non-zeros per rank associated with a single device.*/
-        std::vector<int> nnzInDevWorld;
-
-        /** \brief The row displacements per rank associated with a single device.*/
-        std::vector<int> rowDispls;
-
-        /** \brief The non-zero displacements per rank associated with a single device.*/
-        std::vector<int> nzDispls;
-
-        /** \brief A VecScatter for gathering/scattering between original PETSc
-         *         Vec and temporary redistributed PETSc Vec.*/
-        VecScatter              scatterLhs = nullptr;
-
-        /** \brief A VecScatter for gathering/scattering between original PETSc
-         *         Vec and temporary redistributed PETSc Vec.*/
-        VecScatter              scatterRhs = nullptr;
-
-        /** \brief A temporary PETSc Vec holding redistributed unknowns. */
-        Vec                     redistLhs = nullptr;
-
-        /** \brief A temporary PETSc Vec holding redistributed RHS. */
-        Vec                     redistRhs = nullptr;
-
-
-
 
         /** \brief Set AmgX solver mode based on the user-provided string.
          *
@@ -451,21 +406,18 @@ class AmgXSolver
          */
         PetscErrorCode setMode(const std::string &modeStr);
 
-
         /** \brief Get the number of GPU devices on this computing node.
          *
          * \return PetscErrorCode.
          */
         PetscErrorCode setDeviceCount();
 
-
         /** \brief Set the ID of the corresponding GPU used by this process.
          *
          * \return PetscErrorCode.
          */
         PetscErrorCode setDeviceIDs();
 
-
         /** \brief Initialize all MPI communicators.
          *
          * The \p comm provided will be duplicated and saved to the
@@ -476,7 +428,6 @@ class AmgXSolver
          */
         PetscErrorCode initMPIcomms(const MPI_Comm &comm);
 
-
         /** \brief Perform necessary initialization of AmgX.
          *
          * This function initializes AmgX for current instance. Based on
@@ -487,100 +438,6 @@ class AmgXSolver
          * \return PetscErrorCode.
          */
         PetscErrorCode initAmgX(const std::string &cfgFile);
-
-
-        /** \brief Get IS for the row indices that processes in
-         *      \ref AmgXSolver::gpuWorld "gpuWorld" will held.
-         *
-         * \param A [in] PETSc matrix.
-         * \param devIS [out] PETSc IS.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode getDevIS(const Mat &A, IS &devIS);
-
-
-        /** \brief Get local sequential PETSc Mat of the redistributed matrix.
-         *
-         * \param A [in] Original PETSc Mat.
-         * \param devIS [in] PETSc IS representing redistributed row indices.
-         * \param localA [out] Local sequential redistributed matrix.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode getLocalA(const Mat &A, const IS &devIS, Mat &localA);
-
-
-        /** \brief Redistribute matrix.
-         *
-         * \param A [in] Original PETSc Mat object.
-         * \param devIS [in] PETSc IS representing redistributed rows.
-         * \param newA [out] Redistributed matrix.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode redistMat(const Mat &A, const IS &devIS, Mat &newA);
-
-
-        /** \brief Get \ref AmgXSolver::scatterLhs "scatterLhs" and
-         *      \ref AmgXSolver::scatterRhs "scatterRhs".
-         *
-         * \param A1 [in] Original PETSc Mat object.
-         * \param A2 [in] Redistributed PETSc Mat object.
-         * \param devIS [in] PETSc IS representing redistributed row indices.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode getVecScatter(const Mat &A1, const Mat &A2, const IS &devIS);
-
-
-        /** \brief Get data of compressed row layout of local sparse matrix.
-         *
-         * \param localA [in] Sequential local redistributed PETSc Mat.
-         * \param localN [out] Number of local rows.
-         * \param row [out] Row vector in compressed row layout.
-         * \param col [out] Column vector in compressed row layout.
-         * \param data [out] Data vector in compressed row layout.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode getLocalMatRawData(const Mat &localA,
-                PetscInt &localN, std::vector<PetscInt> &row,
-                std::vector<PetscInt64> &col, std::vector<PetscScalar> &data);
-
-
-        /** \brief Destroy the sequential local redistributed matrix.
-         *
-         * \param A [in] Original PETSc Mat.
-         * \param localA [in, out] Local matrix, returning null pointer.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode destroyLocalA(const Mat &A, Mat &localA);
-
-        /** \brief Check whether the global matrix distribution is contiguous
-         *
-         * If the global matrix is distributed such that contiguous chunks of rows are
-         * distributed over the individual ranks in ascending rank order, the partition vector
-         * has trivial structure (i.e. [0, ..., 0, 1, ..., 1, ..., R-1, ..., R-1] for R ranks) and
-         * its calculation can be skipped since all information is available to AmgX through
-         * the number of ranks and the partition *offsets* (i.e. how many rows are on each rank).
-         *
-         * \param devIS [in] PETSc IS representing redistributed row indices.
-         * \param isContiguous [out] Whether the global matrix is contiguously distributed.
-         * \param partOffsets [out] If contiguous, holds the partition offsets for all R ranks
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode checkForContiguousPartitioning(
-            const IS &devIS, PetscBool &isContiguous, std::vector<PetscInt> &partOffsets);
-
-        /** \brief Get partition data required by AmgX.
-         *
-         * \param devIS [in] PETSc IS representing redistributed row indices.
-         * \param N [in] Total number of rows in global matrix.
-         * \param partData [out] Partition data, either explicit vector or offsets.
-         * \param usesOffsets [out] If PETSC_TRUE, partitioning is contiguous and partData contains
-         *      partition offsets, see checkForContiguousPartitioning(). Otherwise, contains explicit
-         *      partition vector.
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode getPartData(const IS &devIS,
-                const PetscInt &N, std::vector<PetscInt> &partData, PetscBool &usesOffsets);
-
 };
 
 #endif
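
The header now exposes a purely CSR-based setup path in place of the removed PETSc Mat overload. A minimal single-rank lifecycle sketch (the config file name, the 2x2 system, and the per-row partition vector are illustrative assumptions, not taken from the wrapper's tests):

```cpp
#include <petscvec.h>
#include "AmgXSolver.H"

int main(int argc, char** argv)
{
    PetscInitialize(&argc, &argv, nullptr, nullptr);
    {
        AmgXSolver solver;
        solver.initialize(PETSC_COMM_WORLD, "dDDI", "amgx_config.json");

        // hypothetical 2x2 CSR system: [[2,-1],[-1,2]]
        const PetscInt    rowOffsets[] = {0, 2, 4};
        const PetscInt    colIndices[] = {0, 1, 0, 1};
        const PetscScalar values[]     = {2.0, -1.0, -1.0, 2.0};
        const PetscInt    partData[]   = {0, 0};  // both rows owned by rank 0

        solver.setA(/*nGlobalRows=*/2, /*nLocalRows=*/2, /*nLocalNz=*/4,
                rowOffsets, colIndices, values, partData);

        // solution (also the initial guess) and RHS as PETSc Vecs
        Vec p, b;
        VecCreateMPI(PETSC_COMM_WORLD, 2, 2, &p);
        VecDuplicate(p, &b);
        VecSet(p, 0.0);
        VecSet(b, 1.0);

        solver.solve(p, b, 2);  // p now holds the solution

        int iters;
        solver.getIters(iters);

        VecDestroy(&p);
        VecDestroy(&b);
        solver.finalize();
    }
    PetscFinalize();
    return 0;
}
```
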
diff --git a/src/AmgXSolver.cu b/src/AmgXSolver.cu
index c628917f8cd40e1e57b6ad0c1f70febde0f82c82..c92ee06962a7b89008c2d4b3f572c709f65320c3 100644
--- a/src/AmgXSolver.cu
+++ b/src/AmgXSolver.cu
@@ -7,11 +7,9 @@
  *            This project is released under MIT License.
  */
 
-
 // AmgXWrapper
 # include "AmgXSolver.H"
 
-
 // initialize AmgXSolver::count to 0
 int AmgXSolver::count = 0;
 
@@ -19,6 +17,295 @@ int AmgXSolver::count = 0;
 AMGX_resources_handle AmgXSolver::rsrc = nullptr;
 
 
+/* \implements AmgXSolver::AmgXSolver */
+AmgXSolver::AmgXSolver(const MPI_Comm &comm,
+        const std::string &modeStr, const std::string &cfgFile)
+{
+    initialize(comm, modeStr, cfgFile);
+}
+
+
+/* \implements AmgXSolver::~AmgXSolver */
+AmgXSolver::~AmgXSolver()
+{
+    if (isInitialized) finalize();
+}
+
+
+/* \implements AmgXSolver::initialize */
+PetscErrorCode AmgXSolver::initialize(const MPI_Comm &comm,
+        const std::string &modeStr, const std::string &cfgFile)
+{
+    PetscErrorCode      ierr;
+
+    PetscFunctionBeginUser;
+
+    // if this instance has already been initialized, skip
+    if (isInitialized) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE,
+            "This AmgXSolver instance has been initialized on this process.");
+
+    // increase the number of AmgXSolver instances
+    count += 1;
+
+    // get the name of this node
+    int     len;
+    char    name[MPI_MAX_PROCESSOR_NAME];
+    ierr = MPI_Get_processor_name(name, &len); CHK;
+    nodeName = name;
+
+    // get the mode of AmgX solver
+    ierr = setMode(modeStr); CHK;
+
+    // initialize communicators and corresponding information
+    ierr = initMPIcomms(comm); CHK;
+
+    // only processes in gpuWorld are required to initialize AmgX
+    if (gpuProc == 0)
+    {
+        ierr = initAmgX(cfgFile); CHK;
+    }
+
+    // mark this instance as initialized
+    isInitialized = true;
+
+    consolidationStatus = ConsolidationStatus::Uninitialized;
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::setMode */
+PetscErrorCode AmgXSolver::setMode(const std::string &modeStr)
+{
+    PetscFunctionBeginUser;
+
+    if (modeStr == "dDDI")
+        mode = AMGX_mode_dDDI;
+    else if (modeStr == "dDFI")
+        mode = AMGX_mode_dDFI;
+    else if (modeStr == "dFFI")
+        mode = AMGX_mode_dFFI;
+    else if (modeStr[0] == 'h')
+        SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_ARG_WRONG,
+                "CPU mode, %s, is not supported in this wrapper!",
+                modeStr.c_str());
+    else
+        SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_ARG_WRONG,
+                "%s is not an available mode! Available modes are: "
+                "dDDI, dDFI, dFFI.\n", modeStr.c_str());
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::initAmgX */
+PetscErrorCode AmgXSolver::initAmgX(const std::string &cfgFile)
+{
+    PetscFunctionBeginUser;
+
+    // only the first instance (AmgX solver) is in charge of initializing AmgX
+    if (count == 1)
+    {
+        // initialize AmgX
+        AMGX_SAFE_CALL(AMGX_initialize());
+
+        // initialize AmgX plugins
+        AMGX_SAFE_CALL(AMGX_initialize_plugins());
+
+        // only the master process can output something on the screen
+        AMGX_SAFE_CALL(AMGX_register_print_callback(
+                    [](const char *msg, int length)->void
+                    {PetscPrintf(PETSC_COMM_WORLD, "%s", msg);}));
+
+        // let AmgX handle returned errors
+        AMGX_SAFE_CALL(AMGX_install_signal_handler());
+    }
+
+    // create an AmgX configuration object
+    AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, cfgFile.c_str()));
+
+    // let AmgX handle returned error codes internally
+    AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
+
+    // create an AmgX resource object, only the first instance is in charge
+    if (count == 1) AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID);
+
+    // create AmgX vector objects for unknowns and RHS
+    AMGX_vector_create(&AmgXP, rsrc, mode);
+    AMGX_vector_create(&AmgXRHS, rsrc, mode);
+
+    // create the AmgX matrix object
+    AMGX_matrix_create(&AmgXA, rsrc, mode);
+
+    // create an AmgX solver object
+    AMGX_solver_create(&solver, rsrc, mode, cfg);
+
+    // obtain the default number of rings based on current configuration
+    AMGX_config_get_default_number_of_rings(cfg, &ring);
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::finalize */
+PetscErrorCode AmgXSolver::finalize()
+{
+    PetscErrorCode      ierr;
+
+    PetscFunctionBeginUser;
+
+    // skip if this instance has not been initialized
+    if (! isInitialized)
+    {
+        ierr = PetscPrintf(PETSC_COMM_WORLD,
+                "This AmgXWrapper has not been initialized. "
+                "Please initialize it before finalization.\n"); CHK;
+
+        PetscFunctionReturn(0);
+    }
+
+    // only processes using GPU are required to destroy AmgX content
+    if (gpuProc == 0)
+    {
+        // destroy solver instance
+        AMGX_solver_destroy(solver);
+
+        // destroy matrix instance
+        AMGX_matrix_destroy(AmgXA);
+
+        // destroy RHS and unknown vectors
+        AMGX_vector_destroy(AmgXP);
+        AMGX_vector_destroy(AmgXRHS);
+
+        // only the last instance needs to destroy the resources and finalize AmgX
+        if (count == 1)
+        {
+            AMGX_resources_destroy(rsrc);
+            AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
+
+            AMGX_SAFE_CALL(AMGX_finalize_plugins());
+            AMGX_SAFE_CALL(AMGX_finalize());
+        }
+        else
+        {
+            AMGX_config_destroy(cfg);
+        }
+
+        // destroy gpuWorld
+        ierr = MPI_Comm_free(&gpuWorld); CHK;
+    }
+
+    finalizeConsolidation();
+
+    // reset member variables in case the user wants to reuse
+    // this instance for a new initialization
+    gpuProc = MPI_UNDEFINED;
+    ierr = MPI_Comm_free(&globalCpuWorld); CHK;
+    ierr = MPI_Comm_free(&localCpuWorld); CHK;
+    ierr = MPI_Comm_free(&devWorld); CHK;
+
+    // decrease the number of instances
+    count -= 1;
+
+    // change status
+    isInitialized = false;
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::setA */
+PetscErrorCode AmgXSolver::setA(
+    const PetscInt nGlobalRows,
+    const PetscInt nLocalRows,
+    const PetscInt nLocalNz,
+    const PetscInt* rowOffsets,
+    const PetscInt* colIndicesGlobal,
+    const PetscScalar* values,
+    const PetscInt* partData)
+{
+    PetscFunctionBeginUser;
+
+    // Merge the distributed matrix for MPI processes sharing a GPU
+    consolidateMatrix(nLocalRows, nLocalNz, rowOffsets, colIndicesGlobal, values);
+
+    int ierr;
+
+    // upload matrix A to AmgX
+    if (gpuWorld != MPI_COMM_NULL)
+    {
+        ierr = MPI_Barrier(gpuWorld); CHK;
+
+        if (consolidationStatus == ConsolidationStatus::None)
+        {
+            AMGX_matrix_upload_all_global_32(
+                AmgXA, nGlobalRows, nLocalRows, nLocalNz,
+                1, 1, rowOffsets, colIndicesGlobal, values,
+                nullptr, ring, ring, partData);
+        }
+        else
+        {
+            AMGX_matrix_upload_all_global_32(
+                AmgXA, nGlobalRows, nConsRows, nConsNz,
+                1, 1, rowOffsetsCons, colIndicesGlobalCons, valuesCons,
+                nullptr, ring, ring, partData);
+
+            // The rowOffsets and colIndices are no longer needed
+            freeConsStructure();
+        }
+
+        // bind the matrix A to the solver
+        ierr = MPI_Barrier(gpuWorld); CHK;
+        AMGX_solver_setup(solver, AmgXA);
+
+        // connect (bind) vectors to the matrix
+        AMGX_vector_bind(AmgXP, AmgXA);
+        AMGX_vector_bind(AmgXRHS, AmgXA);
+    }
+    ierr = MPI_Barrier(globalCpuWorld); CHK;
+
+    PetscFunctionReturn(0);
+}
+
+
+/* \implements AmgXSolver::updateA */
+PetscErrorCode AmgXSolver::updateA(
+    const PetscInt nLocalRows,
+    const PetscInt nLocalNz,
+    const PetscScalar* values)
+{
+    PetscFunctionBeginUser;
+
+    // Merge the values from multiple MPI processes sharing a single GPU
+    reconsolidateValues(nLocalNz, values);
+
+    int ierr;
+    // Replace the coefficients for the CSR matrix A within AmgX
+    if (gpuWorld != MPI_COMM_NULL)
+    {
+        ierr = MPI_Barrier(gpuWorld); CHK;
+
+        if (consolidationStatus == ConsolidationStatus::None)
+        {
+            AMGX_matrix_replace_coefficients(AmgXA, nLocalRows, nLocalNz, values, nullptr);
+        }
+        else
+        {
+            AMGX_matrix_replace_coefficients(AmgXA, nConsRows, nConsNz, valuesCons, nullptr);
+        }
+
+        ierr = MPI_Barrier(gpuWorld); CHK;
+
+        // Re-setup the solver (a reduced-overhead setup that exploits the unchanged matrix structure)
+        AMGX_solver_resetup(solver, AmgXA);
+    }
+
+    ierr = MPI_Barrier(globalCpuWorld); CHK;
+
+    PetscFunctionReturn(0);
+}
+
+
 /* \implements AmgXSolver::solve */
 PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b, const int nRows)
 {
@@ -171,30 +458,6 @@ PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b, const int nRows)
 }
 
 
-/* \implements AmgXSolver::setMode */
-PetscErrorCode AmgXSolver::setMode(const std::string &modeStr)
-{
-    PetscFunctionBeginUser;
-
-    if (modeStr == "dDDI")
-        mode = AMGX_mode_dDDI;
-    else if (modeStr == "dDFI")
-        mode = AMGX_mode_dDFI;
-    else if (modeStr == "dFFI")
-        mode = AMGX_mode_dFFI;
-    else if (modeStr[0] == 'h')
-        SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_ARG_WRONG,
-                "CPU mode, %s, is not supported in this wrapper!",
-                modeStr.c_str());
-    else
-        SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_ARG_WRONG,
-                "%s is not an available mode! Available modes are: "
-                "dDDI, dDFI, dFFI.\n", modeStr.c_str());
-
-    PetscFunctionReturn(0);
-}
-
-
 /* \implements AmgXSolver::getIters */
 PetscErrorCode AmgXSolver::getIters(int &iter)
 {
diff --git a/src/init.cu b/src/init.cu
deleted file mode 100644
index d6afcd7578c7db434e907a9a2343c2d47b5f1352..0000000000000000000000000000000000000000
--- a/src/init.cu
+++ /dev/null
@@ -1,336 +0,0 @@
-/**
- * \file init.cpp
- * \brief Definition of some member functions of the class AmgXSolver.
- * \author Pi-Yueh Chuang (pychuang@gwu.edu)
- * \date 2016-01-08
- * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
- *            This project is released under MIT License.
- */
-
-
-// CUDA
-# include <cuda_runtime.h>
-
-// AmgXSolver
-# include "AmgXSolver.H"
-# include <iostream>
-
-/* \implements AmgXSolver::AmgXSolver */
-AmgXSolver::AmgXSolver(const MPI_Comm &comm,
-        const std::string &modeStr, const std::string &cfgFile)
-{
-    initialize(comm, modeStr, cfgFile);
-}
-
-
-/* \implements AmgXSolver::~AmgXSolver */
-AmgXSolver::~AmgXSolver()
-{
-    if (isInitialized) finalize();
-}
-
-
-/* \implements AmgXSolver::initialize */
-PetscErrorCode AmgXSolver::initialize(const MPI_Comm &comm,
-        const std::string &modeStr, const std::string &cfgFile)
-{
-    PetscErrorCode      ierr;
-
-    PetscFunctionBeginUser;
-
-    // if this instance has already been initialized, skip
-    if (isInitialized) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE,
-            "This AmgXSolver instance has been initialized on this process.");
-
-    // increase the number of AmgXSolver instances
-    count += 1;
-
-    // get the name of this node
-    int     len;
-    char    name[MPI_MAX_PROCESSOR_NAME];
-    ierr = MPI_Get_processor_name(name, &len); CHK;
-    nodeName = name;
-
-    // get the mode of AmgX solver
-    ierr = setMode(modeStr); CHK;
-
-    // initialize communicators and corresponding information
-    ierr = initMPIcomms(comm); CHK;
-
-    // only processes in gpuWorld are required to initialize AmgX
-    if (gpuProc == 0)
-    {
-        ierr = initAmgX(cfgFile); CHK;
-    }
-
-    // a bool indicating if this instance is initialized
-    isInitialized = true;
-
-    consolidationStatus = ConsolidationStatus::Uninitialized;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::initMPIcomms */
-PetscErrorCode AmgXSolver::initMPIcomms(const MPI_Comm &comm)
-{
-    PetscErrorCode      ierr;
-
-    PetscFunctionBeginUser;
-
-    // duplicate the global communicator
-    ierr = MPI_Comm_dup(comm, &globalCpuWorld); CHK;
-    ierr = MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld"); CHK;
-
-    // get size and rank for global communicator
-    ierr = MPI_Comm_size(globalCpuWorld, &globalSize); CHK;
-    ierr = MPI_Comm_rank(globalCpuWorld, &myGlobalRank); CHK;
-
-
-    // Get the communicator for processors on the same node (local world)
-    ierr = MPI_Comm_split_type(globalCpuWorld,
-            MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localCpuWorld); CHK;
-    ierr = MPI_Comm_set_name(localCpuWorld, "localCpuWorld"); CHK;
-
-    // get size and rank for local communicator
-    ierr = MPI_Comm_size(localCpuWorld, &localSize); CHK;
-    ierr = MPI_Comm_rank(localCpuWorld, &myLocalRank); CHK;
-
-
-    // set up the variable nDevs
-    ierr = setDeviceCount(); CHK;
-
-
-    // set up corresponding ID of the device used by each local process
-    ierr = setDeviceIDs(); CHK;
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-
-    // split the global world into a world involved in AmgX and a null world
-    ierr = MPI_Comm_split(globalCpuWorld, gpuProc, 0, &gpuWorld); CHK;
-
-    // get size and rank for the communicator corresponding to gpuWorld
-    if (gpuWorld != MPI_COMM_NULL)
-    {
-        ierr = MPI_Comm_set_name(gpuWorld, "gpuWorld"); CHK;
-        ierr = MPI_Comm_size(gpuWorld, &gpuWorldSize); CHK;
-        ierr = MPI_Comm_rank(gpuWorld, &myGpuWorldRank); CHK;
-    }
-    else // for those can not communicate with GPU devices
-    {
-        gpuWorldSize = MPI_UNDEFINED;
-        myGpuWorldRank = MPI_UNDEFINED;
-    }
-
-    // split local world into worlds corresponding to each CUDA device
-    ierr = MPI_Comm_split(localCpuWorld, devID, 0, &devWorld); CHK;
-    ierr = MPI_Comm_set_name(devWorld, "devWorld"); CHK;
-
-    // get size and rank for the communicator corresponding to myWorld
-    ierr = MPI_Comm_size(devWorld, &devWorldSize); CHK;
-    ierr = MPI_Comm_rank(devWorld, &myDevWorldRank); CHK;
-
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-    return 0;
-}
-
-
-/* \implements AmgXSolver::setDeviceCount */
-PetscErrorCode AmgXSolver::setDeviceCount()
-{
-    PetscFunctionBeginUser;
-
-    // get the number of devices that AmgX solvers can use
-    switch (mode)
-    {
-        case AMGX_mode_dDDI: // for GPU cases, nDevs is the # of local GPUs
-        case AMGX_mode_dDFI: // for GPU cases, nDevs is the # of local GPUs
-        case AMGX_mode_dFFI: // for GPU cases, nDevs is the # of local GPUs
-            // get the number of total cuda devices
-            CHECK(cudaGetDeviceCount(&nDevs));
-            std::cout << "Number of GPU devices  :: " << nDevs << std::endl;
-
-            // Check whether there is at least one CUDA device on this node
-            if (nDevs == 0) SETERRQ1(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                    "There is no CUDA device on the node %s !\n", nodeName.c_str());
-            break;
-        case AMGX_mode_hDDI: // for CPU cases, nDevs is the # of local processes
-        case AMGX_mode_hDFI: // for CPU cases, nDevs is the # of local processes
-        case AMGX_mode_hFFI: // for CPU cases, nDevs is the # of local processes
-        default:
-            nDevs = localSize;
-            break;
-    }
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::setDeviceIDs */
-PetscErrorCode AmgXSolver::setDeviceIDs()
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    // set the ID of device that each local process will use
-    if (nDevs == localSize) // # of the devices and local precosses are the same
-    {
-        devID = myLocalRank;
-        gpuProc = 0;
-    }
-    else if (nDevs > localSize) // there are more devices than processes
-    {
-        ierr = PetscPrintf(localCpuWorld, "CUDA devices on the node %s "
-                "are more than the MPI processes launched. Only %d CUDA "
-                "devices will be used.\n", nodeName.c_str(), localSize); CHK;
-
-        devID = myLocalRank;
-        gpuProc = 0;
-    }
-    else // there more processes than devices
-    {
-        int     nBasic = localSize / nDevs,
-                nRemain = localSize % nDevs;
-
-        if (myLocalRank < (nBasic+1)*nRemain)
-        {
-            devID = myLocalRank / (nBasic + 1);
-            if (myLocalRank % (nBasic + 1) == 0)  gpuProc = 0;
-        }
-        else
-        {
-            devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain;
-            if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) gpuProc = 0;
-        }
-    }
-
-    // Set the device for each rank
-    cudaSetDevice(devID);
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::initAmgX */
-PetscErrorCode AmgXSolver::initAmgX(const std::string &cfgFile)
-{
-    PetscFunctionBeginUser;
-
-    // only the first instance (AmgX solver) is in charge of initializing AmgX
-    if (count == 1)
-    {
-        // initialize AmgX
-        AMGX_SAFE_CALL(AMGX_initialize());
-
-        // intialize AmgX plugings
-        AMGX_SAFE_CALL(AMGX_initialize_plugins());
-
-        // only the master process can output something on the screen
-        AMGX_SAFE_CALL(AMGX_register_print_callback(
-                    [](const char *msg, int length)->void
-                    {PetscPrintf(PETSC_COMM_WORLD, "%s", msg);}));
-
-        // let AmgX to handle errors returned
-        AMGX_SAFE_CALL(AMGX_install_signal_handler());
-    }
-
-    // create an AmgX configure object
-    AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, cfgFile.c_str()));
-
-    // let AmgX handle returned error codes internally
-    AMGX_SAFE_CALL(AMGX_config_add_parameters(&cfg, "exception_handling=1"));
-
-    // create an AmgX resource object, only the first instance is in charge
-    if (count == 1) AMGX_resources_create(&rsrc, cfg, &gpuWorld, 1, &devID);
-
-    // create AmgX vector object for unknowns and RHS
-    AMGX_vector_create(&AmgXP, rsrc, mode);
-    AMGX_vector_create(&AmgXRHS, rsrc, mode);
-
-    // create AmgX matrix object for unknowns and RHS
-    AMGX_matrix_create(&AmgXA, rsrc, mode);
-
-    // create an AmgX solver object
-    AMGX_solver_create(&solver, rsrc, mode, cfg);
-
-    // obtain the default number of rings based on current configuration
-    AMGX_config_get_default_number_of_rings(cfg, &ring);
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::finalize */
-PetscErrorCode AmgXSolver::finalize()
-{
-    PetscErrorCode      ierr;
-
-    PetscFunctionBeginUser;
-
-    // skip if this instance has not been initialized
-    if (! isInitialized)
-    {
-        ierr = PetscPrintf(PETSC_COMM_WORLD,
-                "This AmgXWrapper has not been initialized. "
-                "Please initialize it before finalization.\n"); CHK;
-
-        PetscFunctionReturn(0);
-    }
-
-    // only processes using GPU are required to destroy AmgX content
-    if (gpuProc == 0)
-    {
-        // destroy solver instance
-        AMGX_solver_destroy(solver);
-
-        // destroy matrix instance
-        AMGX_matrix_destroy(AmgXA);
-
-        // destroy RHS and unknown vectors
-        AMGX_vector_destroy(AmgXP);
-        AMGX_vector_destroy(AmgXRHS);
-
-        // only the last instance need to destroy resource and finalizing AmgX
-        if (count == 1)
-        {
-            AMGX_resources_destroy(rsrc);
-            AMGX_SAFE_CALL(AMGX_config_destroy(cfg));
-
-            AMGX_SAFE_CALL(AMGX_finalize_plugins());
-            AMGX_SAFE_CALL(AMGX_finalize());
-        }
-        else
-        {
-            AMGX_config_destroy(cfg);
-        }
-
-        // destroy gpuWorld
-        ierr = MPI_Comm_free(&gpuWorld); CHK;
-    }
-
-    finalizeConsolidation();
-
-    // destroy PETSc objects
-    ierr = VecScatterDestroy(&scatterLhs); CHK;
-    ierr = VecScatterDestroy(&scatterRhs); CHK;
-    ierr = VecDestroy(&redistLhs); CHK;
-    ierr = VecDestroy(&redistRhs); CHK;
-
-    // re-set necessary variables in case users want to reuse
-    // the variable of this instance for a new instance
-    gpuProc = MPI_UNDEFINED;
-    ierr = MPI_Comm_free(&globalCpuWorld); CHK;
-    ierr = MPI_Comm_free(&localCpuWorld); CHK;
-    ierr = MPI_Comm_free(&devWorld); CHK;
-
-    // decrease the number of instances
-    count -= 1;
-
-    // change status
-    isInitialized = false;
-
-    PetscFunctionReturn(0);
-}
diff --git a/src/setA.cu b/src/setA.cu
deleted file mode 100644
index 1007e11e92522436ba6f288698f3f96aebd2ecba..0000000000000000000000000000000000000000
--- a/src/setA.cu
+++ /dev/null
@@ -1,490 +0,0 @@
-/**
- * \file setA.cpp
- * \brief Definition of member functions regarding setting A in AmgXSolver.
- * \author Pi-Yueh Chuang (pychuang@gwu.edu)
- * \date 2016-01-08
- * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
- * \copyright Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- *            This project is released under MIT License.
- */
-
-
-// STD
-# include <cstring>
-# include <algorithm>
-#include  <iostream>
-
-// AmgXSolver
-# include "AmgXSolver.H"
-
-
-/* \implements AmgXSolver::setA */
-PetscErrorCode AmgXSolver::setA(const Mat &A)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    Mat                 localA;
-
-    IS                  devIS;
-
-    PetscInt            nGlobalRows,
-                        nLocalRows;
-    PetscBool           usesOffsets;
-
-    std::vector<PetscInt>       row;
-    std::vector<PetscInt64>     col;
-    std::vector<PetscScalar>    data;
-    std::vector<PetscInt>       partData;
-
-
-    // get number of rows in global matrix
-    ierr = MatGetSize(A, &nGlobalRows, nullptr); CHK;
-
-    // get the row indices of redistributed matrix owned by processes in gpuWorld
-    ierr = getDevIS(A, devIS); CHK;
-
-    // get sequential local portion of redistributed matrix
-    ierr = getLocalA(A, devIS, localA); CHK;
-
-    // get compressed row layout of the local Mat
-    ierr = getLocalMatRawData(localA, nLocalRows, row, col, data); CHK;
-
-    // destroy local matrix
-    ierr = destroyLocalA(A, localA); CHK;
-
-    // get a partition vector required by AmgX
-    ierr = getPartData(devIS, nGlobalRows, partData, usesOffsets); CHK;
-
-
-    // upload matrix A to AmgX
-    if (gpuWorld != MPI_COMM_NULL)
-    {
-        ierr = MPI_Barrier(gpuWorld); CHK;
-        // offsets need to be 64 bit, since we use 64 bit column indices
-        std::vector<PetscInt64> offsets;
-
-        AMGX_distribution_handle dist;
-        AMGX_distribution_create(&dist, cfg);
-        if (usesOffsets) {
-            offsets.assign(partData.begin(), partData.end());
-            AMGX_distribution_set_partition_data(dist, AMGX_DIST_PARTITION_OFFSETS, offsets.data());
-        } else {
-            AMGX_distribution_set_partition_data(dist, AMGX_DIST_PARTITION_VECTOR, partData.data());
-        }
-
-        AMGX_matrix_upload_distributed(
-                AmgXA, nGlobalRows, nLocalRows, row[nLocalRows],
-                1, 1, row.data(), col.data(), data.data(),
-                nullptr, dist);
-        AMGX_distribution_destroy(dist);
-
-        // bind the matrix A to the solver
-        ierr = MPI_Barrier(gpuWorld); CHK;
-        AMGX_solver_setup(solver, AmgXA);
-
-        // connect (bind) vectors to the matrix
-        AMGX_vector_bind(AmgXP, AmgXA);
-        AMGX_vector_bind(AmgXRHS, AmgXA);
-    }
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-    // destroy temporary PETSc objects
-    ierr = ISDestroy(&devIS); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::getDevIS */
-PetscErrorCode AmgXSolver::getDevIS(const Mat &A, IS &devIS)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-    IS                  tempIS;
-
-    // get index sets of A locally owned by each process
-    // note that devIS is now a serial IS on each process
-    ierr = MatGetOwnershipIS(A, &devIS, nullptr); CHK;
-
-    // concatenate index sets that belong to the same devWorld
-    // note that now devIS is a parallel IS of communicator devWorld
-    ierr = ISOnComm(devIS, devWorld, PETSC_USE_POINTER, &tempIS); CHK;
-    ierr = ISDestroy(&devIS); CHK;
-
-    // all gather in order to have all indices belong to a devWorld on the
-    // leading rank of that devWorld. THIS IS NOT EFFICIENT!!
-    // note that now devIS is again a serial IS on each process
-    ierr = ISAllGather(tempIS, &devIS); CHK;
-    ierr = ISDestroy(&tempIS); CHK;
-
-    // empty devIS on ranks other than the leading ranks in each devWorld
-    if (myDevWorldRank != 0)
-        ierr = ISGeneralSetIndices(devIS, 0, nullptr, PETSC_COPY_VALUES); CHK;
-
-    // devIS is not guaranteed to be sorted. We sort it here.
-    ierr = ISSort(devIS); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::getLocalA */
-PetscErrorCode AmgXSolver::getLocalA(const Mat &A, const IS &devIS, Mat &localA)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-    MatType             type;
-
-    // get the Mat type
-    ierr = MatGetType(A, &type); CHK;
-
-    // check whether the Mat type is supported
-    if (std::strcmp(type, MATSEQAIJ) == 0) // sequential AIJ
-    {
-        // make localA point to the same memory space as A does
-        localA = A;
-    }
-    else if (std::strcmp(type, MATMPIAIJ) == 0)
-    {
-        Mat                 tempA;
-
-        // redistribute matrix and also get corresponding scatters.
-        ierr = redistMat(A, devIS, tempA); CHK;
-
-        // get local matrix from redistributed matrix
-        ierr = MatMPIAIJGetLocalMat(tempA, MAT_INITIAL_MATRIX, &localA); CHK;
-
-        // destroy redistributed matrix
-        if (tempA == A)
-        {
-            tempA = nullptr;
-        }
-        else
-        {
-            ierr = MatDestroy(&tempA); CHK;
-        }
-    }
-    else
-    {
-        SETERRQ1(globalCpuWorld, PETSC_ERR_ARG_WRONG,
-                "Mat type %s is not supported!\n", type);
-    }
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::redistMat */
-PetscErrorCode AmgXSolver::redistMat(const Mat &A, const IS &devIS, Mat &newA)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    if (gpuWorldSize == globalSize) // no redistributation required
-    {
-        newA = A;
-    }
-    else
-    {
-        IS      is;
-
-        // re-set the communicator of devIS to globalCpuWorld
-        ierr = ISOnComm(devIS, globalCpuWorld, PETSC_USE_POINTER, &is); CHK;
-
-        // redistribute the matrix A to newA
-        ierr = MatGetSubMatrix(A, is, is, MAT_INITIAL_MATRIX, &newA); CHK;
-
-        // get VecScatters between original data layout and the new one
-        ierr = getVecScatter(A, newA, is); CHK;
-
-        // destroy the temporary IS
-        ierr = ISDestroy(&is); CHK;
-    }
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::getVecScatter */
-PetscErrorCode AmgXSolver::getVecScatter(
-        const Mat &A1, const Mat &A2, const IS &devIS)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    Vec                 tempLhs;
-    Vec                 tempRhs;
-
-    ierr = MatCreateVecs(A1, &tempLhs, &tempRhs); CHK;
-    ierr = MatCreateVecs(A2, &redistLhs, &redistRhs); CHK;
-
-    ierr = VecScatterCreate(tempLhs, devIS, redistLhs, devIS, &scatterLhs); CHK;
-    ierr = VecScatterCreate(tempRhs, devIS, redistRhs, devIS, &scatterRhs); CHK;
-
-    ierr = VecDestroy(&tempRhs); CHK;
-    ierr = VecDestroy(&tempLhs); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::getLocalMatRawData */
-PetscErrorCode AmgXSolver::getLocalMatRawData(const Mat &localA,
-        PetscInt &localN, std::vector<PetscInt> &row,
-        std::vector<PetscInt64> &col, std::vector<PetscScalar> &data)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    const PetscInt      *rawCol,
-                        *rawRow;
-
-    PetscScalar         *rawData;
-
-    PetscInt            rawN;
-
-    PetscBool           done;
-
-    // get row and column indices in compressed row format
-    ierr = MatGetRowIJ(localA, 0, PETSC_FALSE, PETSC_FALSE,
-            &rawN, &rawRow, &rawCol, &done); CHK;
-
-    // rawN will be returned by MatRestoreRowIJ, so we have to copy it
-    localN = rawN;
-
-    // check if the function worked
-    if (! done)
-        SETERRQ(globalCpuWorld, PETSC_ERR_SIG, "MatGetRowIJ did not work!");
-
-    // get data
-    ierr = MatSeqAIJGetArray(localA, &rawData); CHK;
-
-    // copy values to STL vector. Note: there is an implicit conversion from
-    // PetscInt to PetscInt64 for the column vector
-    col.assign(rawCol, rawCol+rawRow[localN]);
-    row.assign(rawRow, rawRow+localN+1);
-    data.assign(rawData, rawData+rawRow[localN]);
-
-
-    // return ownership of memory space to PETSc
-    ierr = MatRestoreRowIJ(localA, 0, PETSC_FALSE, PETSC_FALSE,
-            &rawN, &rawRow, &rawCol, &done); CHK;
-
-    // check if the function worked
-    if (! done)
-        SETERRQ(globalCpuWorld, PETSC_ERR_SIG, "MatRestoreRowIJ did not work!");
-
-    // return ownership of memory space to PETSc
-    ierr = MatSeqAIJRestoreArray(localA, &rawData); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::destroyLocalA */
-PetscErrorCode AmgXSolver::destroyLocalA(const Mat &A, Mat &localA)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    MatType             type;
-
-    // Get the Mat type
-    ierr = MatGetType(A, &type); CHK;
-
-    // when A is sequential, we can not destroy the memory space
-    if (std::strcmp(type, MATSEQAIJ) == 0)
-    {
-        localA = nullptr;
-    }
-    // for parallel case, localA can be safely destroyed
-    else if (std::strcmp(type, MATMPIAIJ) == 0)
-    {
-        ierr = MatDestroy(&localA); CHK;
-    }
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::checkForContiguousPartitioning */
-PetscErrorCode AmgXSolver::checkForContiguousPartitioning(
-    const IS &devIS, PetscBool &isContiguous, std::vector<PetscInt> &partOffsets)
-{
-    PetscFunctionBeginUser;
-    PetscErrorCode      ierr;
-    PetscBool sorted;
-    PetscInt ismax= -2; // marker for "unsorted", allows to check after global sort
-
-    ierr = ISSorted(devIS, &sorted); CHK;
-    if (sorted)
-    {
-        ierr = ISGetMinMax(devIS, NULL, &ismax); CHK;
-    }
-    partOffsets.resize(gpuWorldSize);
-    ++ismax; // add 1 to allow reusing gathered ismax values as partition offsets for AMGX
-    MPI_Allgather(&ismax, 1, MPIU_INT, &partOffsets[0], 1, MPIU_INT, gpuWorld);
-    bool all_sorted = std::is_sorted(partOffsets.begin(), partOffsets.end())
-                        && partOffsets[0] != -1;
-    if (all_sorted)
-    {
-        partOffsets.insert(partOffsets.begin(), 0); // partition 0 always starts at 0
-        isContiguous = PETSC_TRUE;
-    }
-    else
-    {
-        isContiguous = PETSC_FALSE;
-    }
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::getPartData */
-PetscErrorCode AmgXSolver::getPartData(
-        const IS &devIS, const PetscInt &N, std::vector<PetscInt> &partData, PetscBool &usesOffsets)
-{
-    PetscFunctionBeginUser;
-
-    PetscErrorCode      ierr;
-
-    VecScatter          scatter;
-    Vec                 tempMPI,
-                        tempSEQ;
-
-    PetscInt            n;
-    PetscScalar         *tempPartVec;
-
-    ierr = ISGetLocalSize(devIS, &n); CHK;
-
-    if (gpuWorld != MPI_COMM_NULL)
-    {
-        // check if sorted/contiguous, then we can skip expensive scatters
-        checkForContiguousPartitioning(devIS, usesOffsets, partData);
-        if (!usesOffsets)
-        {
-            ierr = VecCreateMPI(gpuWorld, n, N, &tempMPI); CHK;
-
-            IS      is;
-            ierr = ISOnComm(devIS, gpuWorld, PETSC_USE_POINTER, &is); CHK;
-            ierr = VecISSet(tempMPI, is, (PetscScalar) myGpuWorldRank); CHK;
-            ierr = ISDestroy(&is); CHK;
-
-            ierr = VecScatterCreateToAll(tempMPI, &scatter, &tempSEQ); CHK;
-            ierr = VecScatterBegin(scatter,
-                    tempMPI, tempSEQ, INSERT_VALUES, SCATTER_FORWARD); CHK;
-            ierr = VecScatterEnd(scatter,
-                    tempMPI, tempSEQ, INSERT_VALUES, SCATTER_FORWARD); CHK;
-            ierr = VecScatterDestroy(&scatter); CHK;
-            ierr = VecDestroy(&tempMPI); CHK;
-
-            ierr = VecGetArray(tempSEQ, &tempPartVec); CHK;
-
-            partData.assign(tempPartVec, tempPartVec+N);
-
-            ierr = VecRestoreArray(tempSEQ, &tempPartVec); CHK;
-
-            ierr = VecDestroy(&tempSEQ); CHK;
-        }
-    }
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-
-/* \implements AmgXSolver::setA */
-PetscErrorCode AmgXSolver::setA(
-    const PetscInt nGlobalRows,
-    const PetscInt nLocalRows,
-    const PetscInt nLocalNz,
-    const PetscInt* rowOffsets,
-    const PetscInt* colIndicesGlobal,
-    const PetscScalar* values,
-    const PetscInt* partData)
-{
-    PetscFunctionBeginUser;
-
-    // Merge the distributed matrix for MPI processes sharing a GPU
-    consolidateMatrix(nLocalRows, nLocalNz, rowOffsets, colIndicesGlobal, values);
-
-    int ierr;
-
-    // upload matrix A to AmgX
-    if (gpuWorld != MPI_COMM_NULL)
-    {
-        ierr = MPI_Barrier(gpuWorld); CHK;
-
-        if (consolidationStatus == ConsolidationStatus::None)
-        {
-            AMGX_matrix_upload_all_global_32(
-                AmgXA, nGlobalRows, nLocalRows, nLocalNz,
-                1, 1, rowOffsets, colIndicesGlobal, values,
-                nullptr, ring, ring, partData);
-        }
-        else
-        {
-            AMGX_matrix_upload_all_global_32(
-                AmgXA, nGlobalRows, nConsRows, nConsNz,
-                1, 1, rowOffsetsCons, colIndicesGlobalCons, valuesCons,
-                nullptr, ring, ring, partData);
-
-            // The rowOffsets and colIndices are no longer needed
-            freeConsStructure();
-        }
-
-        // bind the matrix A to the solver
-        ierr = MPI_Barrier(gpuWorld); CHK;
-        AMGX_solver_setup(solver, AmgXA);
-
-        // connect (bind) vectors to the matrix
-        AMGX_vector_bind(AmgXP, AmgXA);
-        AMGX_vector_bind(AmgXRHS, AmgXA);
-    }
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::updateA */
-PetscErrorCode AmgXSolver::updateA(
-    const PetscInt nLocalRows,
-    const PetscInt nLocalNz,
-    const PetscScalar* values)
-{
-    PetscFunctionBeginUser;
-
-    // Merges the values from multiple MPI processes sharing a single GPU
-    reconsolidateValues(nLocalNz, values);
-
-    int ierr;
-    // Replace the coefficients for the CSR matrix A within AmgX
-    if (gpuWorld != MPI_COMM_NULL)
-    {
-        ierr = MPI_Barrier(gpuWorld); CHK;
-
-        if (consolidationStatus == ConsolidationStatus::None)
-        {
-            AMGX_matrix_replace_coefficients(AmgXA, nLocalRows, nLocalNz, values, nullptr);
-        }
-        else
-        {
-            AMGX_matrix_replace_coefficients(AmgXA, nConsRows, nConsNz, valuesCons, nullptr);
-        }
-
-        ierr = MPI_Barrier(gpuWorld); CHK;
-
-        // Re-setup the solver (a reduced overhead setup that accounts for consistent matrix structure)
-        AMGX_solver_resetup(solver, AmgXA);
-    }
-
-    ierr = MPI_Barrier(globalCpuWorld); CHK;
-
-    PetscFunctionReturn(0);
-}