From de8f34e44d312d05b78ec252f4d26693f9b5e067 Mon Sep 17 00:00:00 2001
From: Matt Martineau <mmartineau@nvidia.com>
Date: Tue, 25 May 2021 08:29:55 -0700
Subject: [PATCH] Merge consolidation and transformation so that ranks that
 do not own the GPU perform only memory copies

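Ranks that share a GPU now copy their LDU slices directly into
IPC-mapped consolidation buffers owned by the GPU-owning rank, which is
the only rank that performs the LDU-to-CSR transformation and sort. A
minimal usage sketch follows (illustrative only: the mesh arrays, sizes
and solver configuration are placeholders, and setA is assumed from the
signature in AmgXSolver.H):

    AmgXCSRMatrix A;
    AmgXSolver solver(MPI_COMM_WORLD, "dDDI", "solver.json");
    solver.initialiseMatrixComms(A);  // shares devWorld and gpuProc with A

    // Called on every rank; ranks that do not own the GPU only perform
    // memory copies into the consolidation buffers
    A.setValuesLDU(nLocalRows, nInternalFaces, diagIndexGlobal,
                   lowOffGlobal, uppOffGlobal, upperAddr, lowerAddr,
                   nExtNz, extRow, extCol, diagVals, upperVals,
                   lowerVals, extVals);

    solver.setA(nLocalRows, nGlobalRows, nLocalNz, A);
    solver.solve(nLocalRows, p, b, A);

    A.finalise();
    solver.finalize();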
---
 Make/files               |   1 -
 src/AmgXCSRMatrix.H      | 149 ++++++++-
 src/AmgXCSRMatrix.cu     | 645 ++++++++++++++++++++++++++++-----------
 src/AmgXConsolidation.cu | 490 -----------------------------
 src/AmgXMPIComms.cu      |  12 +-
 src/AmgXSolver.H         | 147 ++-------
 src/AmgXSolver.cu        | 286 ++++++++---------
 7 files changed, 753 insertions(+), 977 deletions(-)
 delete mode 100644 src/AmgXConsolidation.cu

diff --git a/Make/files b/Make/files
index 5662037..fe81f6d 100644
--- a/Make/files
+++ b/Make/files
@@ -1,5 +1,4 @@
 src/AmgXCSRMatrix.cu
-src/AmgXConsolidation.cu  
 src/AmgXMPIComms.cu
 src/AmgXSolver.cu
 
diff --git a/src/AmgXCSRMatrix.H b/src/AmgXCSRMatrix.H
index f015567..0257cf7 100644
--- a/src/AmgXCSRMatrix.H
+++ b/src/AmgXCSRMatrix.H
@@ -22,6 +22,28 @@
 
 #pragma once
 
+#include <vector>
+#include <mpi.h>
+#include <cuda_runtime.h>
+
+/** \brief A set of handles to the device data storing a consolidated CSR matrix. */
+struct ConsolidationHandles
+{
+    cudaIpcMemHandle_t rhsConsHandle;
+    cudaIpcMemHandle_t solConsHandle;
+    cudaIpcMemHandle_t rowIndicesConsHandle;
+    cudaIpcMemHandle_t colIndicesConsHandle;
+    cudaIpcMemHandle_t valuesConsHandle;
+};
+
+/** \brief Enumeration for the status of matrix consolidation for the solver.*/
+enum class ConsolidationStatus
+{
+    Uninitialised,
+    None,
+    Device
+};
+
 class AmgXCSRMatrix
 {
     public:
@@ -59,9 +81,13 @@ class AmgXCSRMatrix
             const double *extVals
         );
 
+        void initialiseComms(
+            MPI_Comm devWorld,
+            int gpuProc);
+
         const int* getColIndices() const
         {
-            return colIndices;
+            return colIndicesGlobal;
         }
 
         const int* getRowOffsets() const
@@ -74,6 +100,36 @@ class AmgXCSRMatrix
             return values;
         }
 
+        double* getPCons()
+        {
+            return pCons;
+        }
+
+        double* getRHSCons()
+        {
+            return rhsCons;
+        }
+
+        const int* getRowDispls() const
+        {
+            return rowDispls.data();
+        }
+
+        int getNConsRows() const
+        {
+            return nConsRows;
+        }
+
+        int getNConsNz() const
+        {
+            return nConsNz + nConsExtNz;
+        }
+
+        bool isConsolidated() const
+        {
+            return consolidationStatus == ConsolidationStatus::Device;
+        }
+
         // Discard elements of the matrix structure
         void discardStructure();
 
@@ -81,16 +137,87 @@ class AmgXCSRMatrix
         void finalise();
 
     private:
+
+        void initialiseConsolidation(
+            const int nLocalRows,
+            const int nLocalNz,
+            const int nInternalFaces,
+            const int nExtNz,
+            int*& rowIndicesTmp,
+            int*& colIndicesTmp);
+
+        void finaliseConsolidation();
+
         // CSR device data for AmgX matrix
-        int *colIndices;
-        int *rowOffsets;
-        double *values;
-
-        // Temporary storage for the permutation
-        double *valuesTmp;
-        
-        // Permutation array to convert between LDU and CSR
-        // Also possibly encodes sorting of columns
-        int *ldu2csrPerm;
+        int *colIndicesGlobal = nullptr;
+
+        int *rowOffsets = nullptr;
+
+        double *values = nullptr;
+
+        /** \brief Temporary storage for the permutation. */
+        double *valuesTmp = nullptr;
+
+        /** \brief The consolidated solution vector. */
+        double* pCons = nullptr;
+
+        /** \brief The consolidated right hand side vector. */
+        double* rhsCons = nullptr;
+
+        /** \brief A flag indicating the type of consolidation applied, if any.
+         * This will be consistent for all ranks within a devWorld. */
+        ConsolidationStatus consolidationStatus = ConsolidationStatus::Uninitialised;
+
+        /** \brief Permutation array to convert between LDU and CSR.
+         *  Also possibly encodes sorting of columns. */
+        int *ldu2csrPerm = nullptr;
+
+        /** \brief The number of non-zeros consolidated from multiple ranks to a device.*/
+        int nConsNz = 0;
+
+        /** \brief The number of rows consolidated from multiple ranks to a device.*/
+        int nConsRows = 0;
+
+        /** \brief The number of internal faces consolidated from multiple ranks to a device.*/
+        int nConsInternalFaces = 0;
+
+        /** \brief The number of external non-zeros consolidated from multiple ranks to a device.*/
+        int nConsExtNz = 0;
+
+        /** \brief The number of rows per rank associated with a single device.*/
+        std::vector<int> nRowsInDevWorld {};
+
+        /** \brief The number of non-zeros per rank associated with a single device.*/
+        std::vector<int> nnzInDevWorld {};
+
+        /** \brief The number of internal faces per rank associated with a single device.*/
+        std::vector<int> nInternalFacesInDevWorld {};
+
+        /** \brief The row displacements per rank associated with a single device.*/
+        std::vector<int> rowDispls {};
+
+        /** \brief The internal face count displacements per rank associated with a single device.*/
+        std::vector<int> internalFacesDispls {};
+
+        /** \brief The non-zero displacements per rank associated with a single device.*/
+        std::vector<int> nzDispls {};
+
+        /** \brief The number of external non-zeros per rank associated with a single device.*/
+        std::vector<int> nExtNzInDevWorld {};
+
+        /** \brief The external non-zero displacements per rank associated with a single device.*/
+        std::vector<int> extNzDispls {};
+
+        /** \brief A communicator for processes sharing the same device. */
+        MPI_Comm devWorld = MPI_COMM_NULL;
+
+        /** \brief A flag indicating if this process will send compute requests to a device. */
+        int gpuProc = MPI_UNDEFINED;
+
+        /** \brief Size of \ref AmgXCSRMatrix::devWorld "devWorld". */
+        int devWorldSize = 0;
+
+        /** \brief Rank in \ref AmgXCSRMatrix::devWorld "devWorld". */
+        int myDevWorldRank = 0;
 };
 
diff --git a/src/AmgXCSRMatrix.cu b/src/AmgXCSRMatrix.cu
index bb05e9d..05c5a76 100644
--- a/src/AmgXCSRMatrix.cu
+++ b/src/AmgXCSRMatrix.cu
@@ -27,6 +27,9 @@
 #include <thrust/scan.h>
 #include <thrust/sequence.h>
 #include <thrust/execution_policy.h>
+#include <numeric>
+
+#include <mpi.h>
 
 #define CHECK(call)                                              \
     {                                                            \
@@ -39,94 +42,213 @@
     }
 
-// Offset the column indices to transform from local to global
-__global__ void localToGlobalColIndices
-(
+// Offset the row indices so they are correct for the consolidated matrix
+__global__ void fixConsolidatedRowIndices(
+    const int nInternalFaces,
+    const int nifDisp,
+    const int nExtNz,
+    const int extDisp,
+    const int nConsRows,
+    const int nConsInternalFaces,
+    const int offset,
+    int *rowIndices)
+{
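+    // The offset is the owning rank's row displacement: a face or halo entry
+    // with local row index r contributed by rank i maps to consolidated row
+    // rowDispls[i] + r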
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+
+    if (i < nInternalFaces)
+    {
+        rowIndices[nConsRows + nifDisp + i] += offset;
+        rowIndices[nConsRows + nConsInternalFaces + nifDisp + i] += offset;
+    }
+
+    if (i < nExtNz)
+    {
+        rowIndices[nConsRows + 2 * nConsInternalFaces + extDisp + i] += offset;
+    }
+}
+
+// Offset the column indices to transform from local to global
+__global__ void localToGlobalColIndices(
     const int nnz,
-    const int nrows,
+    const int nLocalRows,
     const int nInternalFaces,
+    const int rowDisp,
+    const int nifDisp,
+    const int nConsRows,
+    const int nConsInternalFaces,
     const int diagIndexGlobal,
     const int lowOffGlobal,
     const int uppOffGlobal,
-    int *colIndices
-)
+    int *colIndicesGlobal)
 {
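+    // e.g. the diagonal entry for local row i maps to global column
+    // diagIndexGlobal + i, while face entries are shifted by the owning
+    // rank's global lower/upper offsets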
-    // Offset by global offset
-    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < nnz; i += blockDim.x * gridDim.x)
-    {
-        int offset;
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
 
-        // Depending upon the row, different offsets must be applied
-        if (i < nrows)
-        {
-            offset = diagIndexGlobal;
-        }
-        else if (i < nrows + nInternalFaces)
-        {
-            offset = uppOffGlobal;
-        }
-        else
-        {
-            offset = lowOffGlobal;
-        }
+    if (i < nLocalRows)
+    {
+        colIndicesGlobal[rowDisp + i] += diagIndexGlobal;
+    }
 
-        colIndices[i] += offset;
+    if (i < nInternalFaces)
+    {
+        colIndicesGlobal[nConsRows + nifDisp + i] += uppOffGlobal;
+        colIndicesGlobal[nConsRows + nConsInternalFaces + nifDisp + i] += lowOffGlobal;
     }
 }
 
 // Apply the pre-existing permutation to the values [and columns]
-__global__ void applyPermutation
-(
-    const int totalNnz,
+__global__ void applyPermutation(
+    const int nTotalNz,
     const int *perm,
     const int *colIndicesTmp,
     const double *valuesTmp,
-    int *colIndices,
+    int *colIndicesGlobal,
     double *values,
-    bool valuesOnly
-)
+    bool valuesOnly)
 {
-    // Permute col indices and values
-    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < totalNnz; i += blockDim.x * gridDim.x)
-    {
-        int p = perm[i];
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= nTotalNz) return;
 
-        // In the values only case the column indices and row offsets remain fixed so no update
-        if (!valuesOnly)
-        {
-            colIndices[i] = colIndicesTmp[p];
-        }
+    int p = perm[i];
 
-        values[i] = valuesTmp[p];
+    // In the values only case the column indices and row offsets remain fixed so no update
+    if (!valuesOnly)
+    {
+        colIndicesGlobal[i] = colIndicesTmp[p];
     }
+
+    values[i] = valuesTmp[p];
 }
 
 // Flatten the row indices into the row offsets
-__global__ void createRowOffsets
-(
+__global__ void createRowOffsets(
     int nnz,
-    int nrows, 
     int *rowIndices,
-    int *rowOffsets
-)
+    int *rowOffsets)
 {
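+    // e.g. rowIndices = [0,0,1,2,2,2] with nrows = 3 yields counts
+    // [2,1,3,0]; the caller's exclusive scan then produces [0,2,3,6]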
     for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < nnz; i += blockDim.x * gridDim.x)
     {
-        int j = rowIndices[i];
-        atomicAdd(&rowOffsets[j], 1);
+        atomicAdd(&rowOffsets[rowIndices[i]], 1);
     }
 }
 
+void AmgXCSRMatrix::initialiseComms(
+    MPI_Comm devWorld,
+    int gpuProc)
+{
+    this->devWorld = devWorld;
+    this->gpuProc = gpuProc;
+
+    MPI_Comm_rank(this->devWorld, &myDevWorldRank);
+    MPI_Comm_size(this->devWorld, &devWorldSize);
+}
+
+// Initialises the consolidation feature
+void AmgXCSRMatrix::initialiseConsolidation(
+    const int nLocalRows,
+    const int nLocalNz,
+    const int nInternalFaces,
+    const int nExtNz,
+    int*& rowIndicesTmp,
+    int*& colIndicesTmp)
+{
+    // Consolidation has been previously used, must deallocate the structures
+    if (consolidationStatus != ConsolidationStatus::Uninitialised)
+    {
+        finaliseConsolidation();
+    }
+
+    // Check if only one rank is associated with the device of devWorld, and exit
+    // early since no consolidation is done in this case:
+    if (devWorldSize == 1)
+    {
+        // This value will be the same for all ranks within devWorld
+        consolidationStatus = ConsolidationStatus::None;
+
+        // Allocate data only
+        CHECK(cudaMalloc((void **)&rowIndicesTmp, (nLocalNz + nExtNz) * sizeof(int)));
+        CHECK(cudaMalloc((void **)&colIndicesTmp, (nLocalNz + nExtNz) * sizeof(int)));
+        CHECK(cudaMalloc((void **)&valuesTmp, (nLocalNz + nExtNz) * sizeof(double)));
+        return;
+    }
+
+    nRowsInDevWorld.resize(devWorldSize);
+    nnzInDevWorld.resize(devWorldSize);
+    nInternalFacesInDevWorld.resize(devWorldSize);
+    nExtNzInDevWorld.resize(devWorldSize);
+
+    rowDispls.resize(devWorldSize + 1, 0);
+    nzDispls.resize(devWorldSize + 1, 0);
+    internalFacesDispls.resize(devWorldSize + 1, 0);
+    extNzDispls.resize(devWorldSize + 1, 0);
+
+    // Gather to all ranks the local row, non-zero, internal face and
+    // external non-zero counts from every rank sharing the device
+    MPI_Request reqs[4] = { MPI_REQUEST_NULL };
+    MPI_Iallgather(&nLocalRows, 1, MPI_INT, nRowsInDevWorld.data(), 1, MPI_INT, devWorld, &reqs[0]);
+    MPI_Iallgather(&nLocalNz, 1, MPI_INT, nnzInDevWorld.data(), 1, MPI_INT, devWorld, &reqs[1]);
+    MPI_Iallgather(&nInternalFaces, 1, MPI_INT, nInternalFacesInDevWorld.data(), 1, MPI_INT, devWorld, &reqs[2]);
+    MPI_Iallgather(&nExtNz, 1, MPI_INT, nExtNzInDevWorld.data(), 1, MPI_INT, devWorld, &reqs[3]);
+    MPI_Waitall(4, reqs, MPI_STATUSES_IGNORE);
+
+    // Calculate the consolidated numbers of rows, non-zeros, internal faces
+    // and external non-zeros, plus the corresponding per-rank displacements
+    std::partial_sum(nRowsInDevWorld.begin(), nRowsInDevWorld.end(), rowDispls.begin() + 1);
+    std::partial_sum(nnzInDevWorld.begin(), nnzInDevWorld.end(), nzDispls.begin() + 1);
+    std::partial_sum(nInternalFacesInDevWorld.begin(), nInternalFacesInDevWorld.end(), internalFacesDispls.begin() + 1);
+    std::partial_sum(nExtNzInDevWorld.begin(), nExtNzInDevWorld.end(), extNzDispls.begin() + 1);
+
+    nConsNz = nzDispls[devWorldSize];
+    nConsRows = rowDispls[devWorldSize];
+    nConsInternalFaces = internalFacesDispls[devWorldSize];
+    nConsExtNz = extNzDispls[devWorldSize];
+
+    // Consolidate the CSR matrix data from multiple ranks sharing a single GPU to a
+    // root rank, allowing multiple ranks per GPU. This allows overdecomposing the problem
+    // when there are more CPU cores than GPUs, without the inefficiencies of
+    // performing the linear solve on multiple separate domains.
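+    //
+    // Each rank copies its slice into the shared buffers at the displacements
+    // computed above; the consolidated arrays use the LDU-style layout:
+    //   [ diag (nConsRows) | upper (nConsInternalFaces) |
+    //     lower (nConsInternalFaces) | ext (nConsExtNz) ]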
+    ConsolidationHandles handles;
+
+    // The data is already on the GPU so consolidate there
+    if (gpuProc == 0)
+    {
+        // We are consolidating data that already exists on the GPU
+        CHECK(cudaMalloc((void **)&rhsCons, sizeof(double) * nConsRows));
+        CHECK(cudaMalloc((void **)&pCons, sizeof(double) * nConsRows));
+        CHECK(cudaMalloc((void **)&rowIndicesTmp, sizeof(int) * (nConsNz + nConsExtNz)));
+        CHECK(cudaMalloc((void **)&colIndicesTmp, sizeof(int) * (nConsNz + nConsExtNz)));
+        CHECK(cudaMalloc((void **)&valuesTmp, sizeof(double) * (nConsNz + nConsExtNz)));
+
+        CHECK(cudaIpcGetMemHandle(&handles.rhsConsHandle, rhsCons));
+        CHECK(cudaIpcGetMemHandle(&handles.solConsHandle, pCons));
+        CHECK(cudaIpcGetMemHandle(&handles.rowIndicesConsHandle, rowIndicesTmp));
+        CHECK(cudaIpcGetMemHandle(&handles.colIndicesConsHandle, colIndicesTmp));
+        CHECK(cudaIpcGetMemHandle(&handles.valuesConsHandle, valuesTmp));
+    }
+
+    MPI_Bcast(&handles, sizeof(ConsolidationHandles), MPI_BYTE, 0, devWorld);
+
+    // Open memory handles to the consolidated matrix data owned by the GPU-owning process
+    if (gpuProc == MPI_UNDEFINED)
+    {
+        CHECK(cudaIpcOpenMemHandle((void **)&rhsCons, handles.rhsConsHandle, cudaIpcMemLazyEnablePeerAccess));
+        CHECK(cudaIpcOpenMemHandle((void **)&pCons, handles.solConsHandle, cudaIpcMemLazyEnablePeerAccess));
+        CHECK(cudaIpcOpenMemHandle((void **)&rowIndicesTmp, handles.rowIndicesConsHandle, cudaIpcMemLazyEnablePeerAccess));
+        CHECK(cudaIpcOpenMemHandle((void **)&colIndicesTmp, handles.colIndicesConsHandle, cudaIpcMemLazyEnablePeerAccess));
+        CHECK(cudaIpcOpenMemHandle((void **)&valuesTmp, handles.valuesConsHandle, cudaIpcMemLazyEnablePeerAccess));
+    }
+
+    // This value will be the same for all ranks within devWorld
+    consolidationStatus = ConsolidationStatus::Device;
+}
+
 // Perform the conversion between an LDU matrix and a CSR matrix, possibly distributed
 void AmgXCSRMatrix::setValuesLDU
 (
-    int nrows,
+    int nLocalRows,
     int nInternalFaces,
     int diagIndexGlobal,
     int lowOffGlobal,
     int uppOffGlobal,
     const int *upperAddr,
     const int *lowerAddr,
-    const int extNnz,
+    const int nExtNz,
     const int *extRow,
     const int *extCol,
     const double *diagVals,
@@ -136,179 +258,340 @@ void AmgXCSRMatrix::setValuesLDU
 )
 {
     // Determine the local non-zeros from the internal faces
-    int localNnz = nrows + 2 * nInternalFaces;
+    int nLocalNz = nLocalRows + 2 * nInternalFaces;
+    int *rowIndicesTmp = nullptr;
+    int *permTmp = nullptr;
+    int *colIndicesTmp = nullptr;
 
-    // Add external non-zeros (communicated halo entries)
-    int totalNnz = localNnz + extNnz;
+    initialiseConsolidation(nLocalRows, nLocalNz, nInternalFaces, nExtNz, rowIndicesTmp, colIndicesTmp);
 
-    // Generate unpermuted index list [0, ..., totalNnz-1]
-    int *permTmp;
-    CHECK(cudaMalloc(&permTmp, sizeof(int) * totalNnz));
-    thrust::sequence(thrust::device, permTmp, permTmp + totalNnz, 0);
+    int nTotalNz = 0;
+    int nRows = 0;
 
-    // Fill rowIndicesTmp with [0, ..., n-1], lowerAddr, upperAddr, (extAddr)
-    int *rowIndicesTmp;
-    CHECK(cudaMalloc(&rowIndicesTmp, sizeof(int) * totalNnz));
-    CHECK(cudaMemcpy(rowIndicesTmp, permTmp, nrows * sizeof(int), cudaMemcpyDefault));
-    CHECK(cudaMemcpy(rowIndicesTmp + nrows, lowerAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
-    CHECK(cudaMemcpy(rowIndicesTmp + nrows + nInternalFaces, upperAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
-    if (extNnz > 0)
+    std::vector<int> diagIndexGlobalAll(devWorldSize);
+    std::vector<int> lowOffGlobalAll(devWorldSize);
+    std::vector<int> uppOffGlobalAll(devWorldSize);
+
+    switch (consolidationStatus)
     {
-        CHECK(cudaMemcpy(rowIndicesTmp + localNnz, extRow, extNnz * sizeof(int), cudaMemcpyDefault));
-    }
 
-    // Make space for the row indices and stored permutation
-    int *rowIndices;
-    CHECK(cudaMalloc(&rowIndices, sizeof(int) * totalNnz));
-    CHECK(cudaMalloc(&ldu2csrPerm, sizeof(int) * totalNnz));
-
-    // Sort the row indices and store results in the permutation
-    void *tempStorage = NULL;
-    size_t tempStorageBytes = 0;
-    cub::DeviceRadixSort::SortPairs(tempStorage, tempStorageBytes, rowIndicesTmp, rowIndices, permTmp, ldu2csrPerm, totalNnz);
-    CHECK(cudaMalloc(&tempStorage, tempStorageBytes));
-    cub::DeviceRadixSort::SortPairs(tempStorage, tempStorageBytes, rowIndicesTmp, rowIndices, permTmp, ldu2csrPerm, totalNnz);
-    CHECK(cudaFree(permTmp));
-    CHECK(cudaFree(tempStorage));
-
-    // Make space for the row offsets
-    CHECK(cudaMalloc(&rowOffsets, sizeof(int) * (nrows + 1)));
-    CHECK(cudaMemset(rowOffsets, 0, sizeof(int) * (nrows + 1)));
-
-    // XXX Taking the non zero per row data from the host could be more
-    // efficient, experiment with this in the future
-    //cudaMemcpy(rowOffsets, nz_per_row, sizeof(int) * nrows, cudaMemcpyDefault);
-    //thrust::exclusive_scan(thrust::device, rowOffsets, rowOffsets + nrows + 1, rowOffsets);
-
-    // Convert the row indices into offsets
-    constexpr int nthreads = 128;
-    int nblocks = totalNnz / nthreads + 1;
-    createRowOffsets<<<nblocks, nthreads>>>(totalNnz, nrows, rowIndices, rowOffsets);
-    thrust::exclusive_scan(thrust::device, rowOffsets, rowOffsets + nrows + 1, rowOffsets);
-    CHECK(cudaFree(rowIndices));
-
-    // Fill rowIndicesTmp with diagVals, upperVals, lowerVals, (extVals)
-    CHECK(cudaMalloc(&valuesTmp, totalNnz * sizeof(double)));
-    CHECK(cudaMemcpy(valuesTmp, diagVals, nrows * sizeof(double), cudaMemcpyDefault));
-    CHECK(cudaMemcpy(valuesTmp + nrows, upperVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
-    // symmetric matrices
-    if (lowerVals == upperVals)
+    case ConsolidationStatus::None:
+    {
+        nTotalNz = nLocalNz + nExtNz;
+        nRows = nLocalRows;
+
+        // Generate unpermuted index list [0, ..., nTotalNz-1]
+        CHECK(cudaMalloc(&permTmp, sizeof(int) * nTotalNz));
+        thrust::sequence(thrust::device, permTmp, permTmp + nTotalNz, 0);
+
+        // Fill rowIndicesTmp with [0, ..., n-1], lowerAddr, upperAddr, (extAddr)
+        thrust::sequence(thrust::device, rowIndicesTmp, rowIndicesTmp + nRows, 0);
+        CHECK(cudaMemcpy(rowIndicesTmp + nRows, lowerAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(rowIndicesTmp + nRows + nInternalFaces, upperAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(rowIndicesTmp + nLocalNz, extRow, nExtNz * sizeof(int), cudaMemcpyDefault));
+        }
+
+        // Concat [0, ..., n-1], upperAddr, lowerAddr (note switched) into column indices
+        thrust::sequence(thrust::device, colIndicesTmp, colIndicesTmp + nRows, 0);
+        CHECK(cudaMemcpy(colIndicesTmp + nRows, rowIndicesTmp + nRows + nInternalFaces, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(colIndicesTmp + nRows + nInternalFaces, rowIndicesTmp + nRows, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(colIndicesTmp + nLocalNz, extCol, nExtNz * sizeof(int), cudaMemcpyDefault));
+        }
+
+        // Fill valuesTmp with diagVals, upperVals, lowerVals, (extVals)
+        CHECK(cudaMemcpy(valuesTmp, diagVals, nRows * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nRows, upperVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nRows + nInternalFaces, lowerVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(valuesTmp + nLocalNz, extVals, nExtNz * sizeof(double), cudaMemcpyDefault));
+        }
+        break;
+    }
+    case ConsolidationStatus::Device:
     {
-        CHECK(cudaMemcpy(valuesTmp + nrows + nInternalFaces, 
-                         valuesTmp + nrows, 
-                         nInternalFaces * sizeof(double), 
-                         cudaMemcpyDefault));    
+        nTotalNz = nConsNz + nConsExtNz;
+        nRows = nConsRows;
+
+        // Copy the data to the consolidation buffer
+        // Fill rowIndicesTmp with lowerAddr, upperAddr, (extAddr)
+        CHECK(cudaMemcpy(rowIndicesTmp + nConsRows + internalFacesDispls[myDevWorldRank], lowerAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(rowIndicesTmp + nConsRows + nConsInternalFaces + internalFacesDispls[myDevWorldRank], upperAddr, nInternalFaces * sizeof(int), cudaMemcpyDefault));
+
+        // Fill valuesTmp with diagVals, upperVals, lowerVals, (extVals)
+        CHECK(cudaMemcpy(valuesTmp + rowDispls[myDevWorldRank], diagVals, nLocalRows * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nConsRows + internalFacesDispls[myDevWorldRank], upperVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nConsRows + nConsInternalFaces + internalFacesDispls[myDevWorldRank], lowerVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(rowIndicesTmp + nConsNz + extNzDispls[myDevWorldRank], extRow, nExtNz * sizeof(int), cudaMemcpyDefault));
+            CHECK(cudaMemcpy(colIndicesTmp + nConsNz + extNzDispls[myDevWorldRank], extCol, nExtNz * sizeof(int), cudaMemcpyDefault));
+            CHECK(cudaMemcpy(valuesTmp + nConsNz + extNzDispls[myDevWorldRank], extVals, nExtNz * sizeof(double), cudaMemcpyDefault));
+        }
+
+        // The copies above are device-to-device, for which cudaMemcpy does not
+        // block the host, so synchronize with the device to ensure the operations
+        // are complete. Barrier on all devWorld ranks to ensure the full arrays
+        // are populated before the root process uses the data.
+        CHECK(cudaDeviceSynchronize());
+        MPI_Barrier(devWorld);
+
+        if (gpuProc == 0)
+        {
+            // Generate unpermuted index list [0, ..., nTotalNz-1]
+            CHECK(cudaMalloc(&permTmp, sizeof(int) * nTotalNz));
+            thrust::sequence(thrust::device, permTmp, permTmp + nTotalNz, 0);
+
+            // Fill the diagonal row indices with [0, ..., nConsRows-1], then
+            // concat upperAddr, lowerAddr (note switched) into the face column
+            // indices
+            thrust::sequence(thrust::device, rowIndicesTmp, rowIndicesTmp + nConsRows, 0);
+            CHECK(cudaMemcpy(colIndicesTmp + nConsRows, rowIndicesTmp + nConsRows + nConsInternalFaces, nConsInternalFaces * sizeof(int), cudaMemcpyDefault));
+            CHECK(cudaMemcpy(colIndicesTmp + nConsRows + nConsInternalFaces, rowIndicesTmp + nConsRows, nConsInternalFaces * sizeof(int), cudaMemcpyDefault));
+
+            for (int i = 0; i < devWorldSize; ++i)
+            {
+                int nrows = nRowsInDevWorld[i];
+                int rowDisp = rowDispls[i];
+
+                thrust::sequence(thrust::device, colIndicesTmp + rowDisp, colIndicesTmp + rowDisp + nrows, 0);
+
+                // Skip the first rank, whose offset is 0
+                if (i > 0)
+                {
+                    int nif = nInternalFacesInDevWorld[i];
+                    int nifDisp = internalFacesDispls[i];
+                    int extDisp = extNzDispls[i];
+                    int nenz = nExtNzInDevWorld[i];
+
+                    // Adjust rowIndices so that they are correct for the consolidated matrix
+                    int nthreads = 128;
+                    int nblocks = nif / nthreads + 1;
+                    fixConsolidatedRowIndices<<<nblocks, nthreads>>>(nif, nifDisp, nenz, extDisp, nConsRows, nConsInternalFaces, rowDisp, rowIndicesTmp);
+                }
+            }
+        }
+        else
+        {
+            // Close IPC handles and deallocate for consolidation
+            CHECK(cudaIpcCloseMemHandle(rowIndicesTmp));
+            CHECK(cudaIpcCloseMemHandle(colIndicesTmp));
+        }
+
+        MPI_Request reqs[3] = { MPI_REQUEST_NULL };
+        MPI_Igather(&diagIndexGlobal, 1, MPI_INT, diagIndexGlobalAll.data(), 1, MPI_INT, 0, devWorld, &reqs[0]);
+        MPI_Igather(&lowOffGlobal, 1, MPI_INT, lowOffGlobalAll.data(), 1, MPI_INT, 0, devWorld, &reqs[1]);
+        MPI_Igather(&uppOffGlobal, 1, MPI_INT, uppOffGlobalAll.data(), 1, MPI_INT, 0, devWorld, &reqs[2]);
+        MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE);
+
+        break;
     }
-    else
+    case ConsolidationStatus::Uninitialised:
     {
-        CHECK(cudaMemcpy(valuesTmp + nrows + nInternalFaces, 
-                         lowerVals, 
-                         nInternalFaces * sizeof(double), 
-                         cudaMemcpyDefault));
+        fprintf(stderr, "Attempting to consolidate before consolidation is initialised.\n");
+        break;
     }
-    if (extNnz > 0)
+    default:
     {
-        CHECK(cudaMemcpy(valuesTmp + localNnz, 
-                         extVals, 
-                         extNnz * sizeof(double), 
-                         cudaMemcpyDefault));
+        fprintf(stderr, "Incorrect consolidation status set.\n");
+        break;
+    }
     }
 
-    // Concat [0, ..., n-1], upperAddr, lowerAddr (note switched) into column indices
-    int *colIndicesTmp;
-    CHECK(cudaMalloc(&colIndicesTmp, totalNnz * sizeof(int)));
-    CHECK(cudaMemcpy(colIndicesTmp, rowIndicesTmp, nrows * sizeof(int), cudaMemcpyDefault));
-    CHECK(cudaMemcpy(colIndicesTmp + nrows, rowIndicesTmp + nrows + nInternalFaces, nInternalFaces * sizeof(int), cudaMemcpyDefault));
-    CHECK(cudaMemcpy(colIndicesTmp + nrows + nInternalFaces, rowIndicesTmp + nrows, nInternalFaces * sizeof(int), cudaMemcpyDefault));
-    if (extNnz > 0)
+    if (gpuProc == 0)
     {
-        CHECK(cudaMemcpy(colIndicesTmp + localNnz, extCol, extNnz * sizeof(int), cudaMemcpyDefault));
-    }
+        // Make space for the row indices and stored permutation
+        int *rowIndices;
+        CHECK(cudaMalloc(&rowIndices, sizeof(int) * nTotalNz));
+        CHECK(cudaMalloc(&ldu2csrPerm, sizeof(int) * nTotalNz));
+
+        cub::DoubleBuffer<int> d_keys(rowIndicesTmp, rowIndices);
+        cub::DoubleBuffer<int> d_values(permTmp, ldu2csrPerm);
+
+        // Sort the row indices and store results in the permutation
+        void *tempStorage = NULL;
+        size_t tempStorageBytes = 0;
+        cub::DeviceRadixSort::SortPairs(tempStorage, tempStorageBytes, d_keys, d_values, nTotalNz);
+        CHECK(cudaMalloc(&tempStorage, tempStorageBytes));
+        cub::DeviceRadixSort::SortPairs(tempStorage, tempStorageBytes, d_keys, d_values, nTotalNz);
+        if (tempStorageBytes > 0)
+        {
+            CHECK(cudaFree(tempStorage));
+        }
+
+        // Free the CUB double-buffer allocations that do not hold the results
+        CHECK(cudaFree(d_keys.Alternate()));
+        CHECK(cudaFree(d_values.Alternate()));
+
+        // Fetch the result pointers from the CUB ping-pong buffers
+        rowIndices = d_keys.Current();
+        ldu2csrPerm = d_values.Current();
+
+        // Make space for the row offsets
+        CHECK(cudaMalloc(&rowOffsets, sizeof(int) * (nRows + 1)));
+        CHECK(cudaMemset(rowOffsets, 0, sizeof(int) * (nRows + 1)));
 
-    CHECK(cudaFree(rowIndicesTmp));
+        // Convert the row indices into offsets
+        constexpr int nthreads = 128;
+        int nblocks = nTotalNz / nthreads + 1;
+        createRowOffsets<<<nblocks, nthreads>>>(nTotalNz, rowIndices, rowOffsets);
+        thrust::exclusive_scan(thrust::device, rowOffsets, rowOffsets + nRows + 1, rowOffsets);
+        CHECK(cudaFree(rowIndices));
+
+        // Transform the local column indices to global column indices
+        if (isConsolidated())
+        {
+            for (int i = 0; i < devWorldSize; ++i)
+            {
+                nblocks = nnzInDevWorld[i] / nthreads + 1;
+                localToGlobalColIndices<<<nblocks, nthreads>>>(nnzInDevWorld[i], nRowsInDevWorld[i], nInternalFacesInDevWorld[i],
+                                                               rowDispls[i], internalFacesDispls[i], nConsRows,
+                                                               nConsInternalFaces, diagIndexGlobalAll[i],
+                                                               lowOffGlobalAll[i], uppOffGlobalAll[i], colIndicesTmp);
+            }
+        }
+        else
+        {
+            nblocks = nLocalNz / nthreads + 1;
+            localToGlobalColIndices<<<nblocks, nthreads>>>(nLocalNz, nLocalRows, nInternalFaces, 0, 0,
+                                                           nLocalRows, nInternalFaces, diagIndexGlobal,
+                                                           lowOffGlobal, uppOffGlobal, colIndicesTmp);
+        }
 
-    // Construct the global column indices
-    nblocks = localNnz / nthreads + 1;
-    localToGlobalColIndices<<<nblocks, nthreads>>>(localNnz, nrows, nInternalFaces, diagIndexGlobal, lowOffGlobal, uppOffGlobal, colIndicesTmp);
+        // Allocate space to store the permuted column indices and values
+        CHECK(cudaMalloc(&colIndicesGlobal, sizeof(int) * nTotalNz));
+        CHECK(cudaMalloc(&values, sizeof(double) * nTotalNz));
 
-    // Allocate space to store the permuted column indices and values
-    CHECK(cudaMalloc(&colIndices, sizeof(int) * totalNnz));
-    CHECK(cudaMalloc(&values, sizeof(double) * totalNnz));
+        // Swap column indices based on the pre-determined permutation
+        nblocks = nTotalNz / nthreads + 1;
+        applyPermutation<<<nblocks, nthreads>>>(nTotalNz, ldu2csrPerm, colIndicesTmp, valuesTmp, colIndicesGlobal, values, false);
 
-    // Swap column indices based on the pre-determined permutation
-    nblocks = totalNnz / nthreads + 1;
-    applyPermutation<<<nblocks, nthreads>>>(totalNnz, ldu2csrPerm, colIndicesTmp, valuesTmp, colIndices, values, false);
-    CHECK(cudaFree(colIndicesTmp));
+        CHECK(cudaFree(colIndicesTmp));
+    }
 }
 
 // Updates the values based on the previously determined permutation
 void AmgXCSRMatrix::updateValues
 (
-    const int nrows,
+    const int nLocalRows,
     const int nInternalFaces,
-    const int extNnz,
-    const double *diagVal,
-    const double *uppVal,
-    const double *lowVal,
-    const double *extVal
+    const int nExtNz,
+    const double *diagVals,
+    const double *uppVals,
+    const double *lowVals,
+    const double *extVals
 )
 {
-    // Determine the local non-zeros from the internal faces
-    int localNnz = nrows + 2 * nInternalFaces;
-
-    // Add external non-zeros (communicated halo entries)
-    int totalNnz = localNnz + extNnz;
+    // Total number of non-zeros, including external (communicated halo) entries
+    int nTotalNz;
 
-    // Copy the values in [ diag, upper, lower, (external) ]
-    CHECK(cudaMemcpy(valuesTmp, diagVal, sizeof(double) * nrows, cudaMemcpyDefault));
-    CHECK(cudaMemcpy(valuesTmp + nrows, uppVal, sizeof(double) * nInternalFaces, cudaMemcpyDefault));
-    // symmetric matrices
-    if (lowVal == uppVal)
+    if (isConsolidated())
     {
-        CHECK(cudaMemcpy(valuesTmp + nrows + nInternalFaces, 
-                         valuesTmp + nrows, 
-                         sizeof(double) * nInternalFaces, 
-                         cudaMemcpyDefault));
+        nTotalNz = nConsNz + nConsExtNz;
+
+        // Fill valuesTmp with diagVals, upperVals, lowerVals, (extVals)
+        CHECK(cudaMemcpy(valuesTmp + rowDispls[myDevWorldRank], diagVals, nLocalRows * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nConsRows + internalFacesDispls[myDevWorldRank], uppVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nConsRows + nConsInternalFaces + internalFacesDispls[myDevWorldRank], lowVals, nInternalFaces * sizeof(double), cudaMemcpyDefault));
+
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(valuesTmp + nConsNz + extNzDispls[myDevWorldRank], extVals, nExtNz * sizeof(double), cudaMemcpyDefault));
+        }
+
+        // Ensure that all ranks associated with a device have completed prior to the subsequent permutation
+        CHECK(cudaDeviceSynchronize());
+        MPI_Barrier(devWorld);
     }
     else
     {
-        CHECK(cudaMemcpy(valuesTmp + nrows + nInternalFaces, 
-                         lowVal, 
-                         sizeof(double) * nInternalFaces, 
-                         cudaMemcpyDefault));
-    }
-    if (extNnz > 0)
-    {
-        CHECK(cudaMemcpy(valuesTmp + localNnz, 
-                         extVal, 
-                         sizeof(double) * extNnz, 
-                         cudaMemcpyDefault));
+        int nLocalNz = nLocalRows + 2 * nInternalFaces;
+        nTotalNz = nLocalNz + nExtNz;
+
+        // Copy the values in [ diag, upper, lower, (external) ]
+        CHECK(cudaMemcpy(valuesTmp, diagVals, sizeof(double) * nLocalRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nLocalRows, uppVals, sizeof(double) * nInternalFaces, cudaMemcpyDefault));
+        CHECK(cudaMemcpy(valuesTmp + nLocalRows + nInternalFaces, lowVals, sizeof(double) * nInternalFaces, cudaMemcpyDefault));
+        if (nExtNz > 0)
+        {
+            CHECK(cudaMemcpy(valuesTmp + nLocalNz, extVals, sizeof(double) * nExtNz, cudaMemcpyDefault));
+        }
     }
 
-    constexpr int nthreads = 128;
-    int nblocks = totalNnz / nthreads + 1;
-    applyPermutation<<<nblocks, nthreads>>>(totalNnz, ldu2csrPerm, nullptr, valuesTmp, nullptr, values, true);
-    CHECK(cudaStreamSynchronize(0));
-}
+    if (gpuProc == 0)
+    {
+        constexpr int nthreads = 128;
+        int nblocks = nTotalNz / nthreads + 1;
+        applyPermutation<<<nblocks, nthreads>>>(nTotalNz, ldu2csrPerm, nullptr, valuesTmp, nullptr, values, true);
 
-// XXX Should implement an early abandonment of the
-// unnecessary data for capacity optimisation
-void AmgXCSRMatrix::discardStructure()
-{
+        // Synchronize to ensure any kernel errors are caught here and to avoid
+        // issues if callers subsequently use non-blocking streams.
+        CHECK(cudaDeviceSynchronize());
+    }
 }
 
 // Deallocate remaining storage
 void AmgXCSRMatrix::finalise()
 {
-    if(rowOffsets != nullptr)
+    switch (consolidationStatus)
+    {
+
+    case ConsolidationStatus::None:
+    {
+        CHECK(cudaFree(valuesTmp));
+        CHECK(cudaFree(ldu2csrPerm));
         CHECK(cudaFree(rowOffsets));
-    if(colIndices != nullptr)
-        CHECK(cudaFree(colIndices));
-    if(values != nullptr)
+        CHECK(cudaFree(colIndicesGlobal));
         CHECK(cudaFree(values));
-    if(valuesTmp != nullptr)
-        CHECK(cudaFree(valuesTmp));
-    if(ldu2csrPerm != nullptr)
-        CHECK(cudaFree(ldu2csrPerm));
+        break;
+    }
+    case ConsolidationStatus::Uninitialised:
+    {
+        // Nothing was allocated, so there is nothing to free
+        return;
+    }
+    case ConsolidationStatus::Device:
+    {
+        if (gpuProc == 0)
+        {
+            // Deallocate the CSR matrix values, solution and RHS
+            CHECK(cudaFree(pCons));
+            CHECK(cudaFree(rhsCons));
+            CHECK(cudaFree(valuesTmp));
+            CHECK(cudaFree(ldu2csrPerm));
+            CHECK(cudaFree(rowOffsets));
+            CHECK(cudaFree(colIndicesGlobal));
+            CHECK(cudaFree(values));
+        }
+        else
+        {
+            // Close the remaining IPC memory handles
+            CHECK(cudaIpcCloseMemHandle(pCons));
+            CHECK(cudaIpcCloseMemHandle(rhsCons));
+            CHECK(cudaIpcCloseMemHandle(valuesTmp));
+        }
+        break;
+    }
+    default:
+    {
+        fprintf(stderr, "Incorrect consolidation status set.\n");
+        break;
+    }
+    }
+
+    // Free the local GPU partitioning structures
+    if (isConsolidated())
+    {
+        nRowsInDevWorld.clear();
+        nnzInDevWorld.clear();
+        nInternalFacesInDevWorld.clear();
+        nExtNzInDevWorld.clear();
+        rowDispls.clear();
+        internalFacesDispls.clear();
+        nzDispls.clear();
+        extNzDispls.clear();
+    }
+
+    consolidationStatus = ConsolidationStatus::Uninitialised;
 }
 
+void AmgXCSRMatrix::finaliseConsolidation()
+{
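+    // Currently a no-op: all consolidation state is released in finalise()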
+}
diff --git a/src/AmgXConsolidation.cu b/src/AmgXConsolidation.cu
deleted file mode 100644
index 6cb9d93..0000000
--- a/src/AmgXConsolidation.cu
+++ /dev/null
@@ -1,490 +0,0 @@
-/*
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and associated documentation files (the "Software"),
- * to deal in the Software without restriction, including without limitation
- * the rights to use, copy, modify, merge, publish, distribute, sublicense,
- * and/or sell copies of the Software, and to permit persons to whom the
- * Software is furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
- * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
- * DEALINGS IN THE SOFTWARE.
- * \file AmgXConsolidation.cu
- * \brief Definition of member functions related to matrix consolidation.
- * \author Matt Martineau (mmartineau@nvidia.com)
- * \date 2020-07-31
- */
-
-#include <AmgXSolver.H>
-
-#include <numeric>
-
-/*
-    Changes local row offsets to describe the consolidated row space on
-    the root rank.
-*/
-__global__ void fixConsolidatedRowOffsets
-(
-    int nLocalRows, 
-    int offset, 
-    int* rowOffsets)
-{
-    for(int i = threadIdx.x + blockIdx.x*blockDim.x; i < nLocalRows; i += blockDim.x*gridDim.x)
-    {
-        rowOffsets[i] += offset;
-    }
-}
-
-// A set of handles to the device data storing a consolidated CSR matrix
-struct ConsolidationHandles
-{
-    cudaIpcMemHandle_t rhsConsHandle;
-    cudaIpcMemHandle_t solConsHandle;
-    cudaIpcMemHandle_t rowOffsetsConsHandle;
-    cudaIpcMemHandle_t colIndicesConsHandle;
-    cudaIpcMemHandle_t valuesConsHandle;
-};
-
-/* \implements AmgXSolver::initializeConsolidation */
-PetscErrorCode AmgXSolver::initializeConsolidation
-(
-    const PetscInt nLocalRows,
-    const PetscInt nLocalNz,
-    const PetscScalar* values
-)
-{
-    PetscFunctionBeginUser;
-
-    // Check if multiple ranks are associated with a device
-    if (devWorldSize == 1)
-    {
-        consolidationStatus = ConsolidationStatus::None;
-        PetscFunctionReturn(0);
-    }
-
-    nRowsInDevWorld.resize(devWorldSize);
-    nnzInDevWorld.resize(devWorldSize);
-    rowDispls.resize(devWorldSize+1, 0);
-    nzDispls.resize(devWorldSize+1, 0);
-
-    // Fetch to all the number of local rows on each rank
-    MPI_Request req[2];
-    int ierr = MPI_Iallgather(
-        &nLocalRows, 
-        1,
-        MPI_INT, 
-        nRowsInDevWorld.data(), 
-        1, 
-        MPI_INT, 
-        devWorld, 
-        &req[0]); CHK;
-
-    // Fetch to all the number of non zeros on each rank
-    ierr = MPI_Iallgather(
-        &nLocalNz, 
-        1, 
-        MPI_INT, 
-        nnzInDevWorld.data(), 
-        1, 
-        MPI_INT, 
-        devWorld, 
-        &req[1]); CHK;
-    MPI_Waitall(2, req, MPI_STATUSES_IGNORE);
-
-    // Calculate consolidate number of rows, non-zeros, and calculate row, non-zero displacements
-    nConsNz = std::accumulate(nnzInDevWorld.begin(), nnzInDevWorld.end(), 0);
-    nConsRows = std::accumulate(nRowsInDevWorld.begin(), nRowsInDevWorld.end(), 0);
-    std::partial_sum(nRowsInDevWorld.begin(), nRowsInDevWorld.end(), rowDispls.begin()+1);
-    std::partial_sum(nnzInDevWorld.begin(), nnzInDevWorld.end(), nzDispls.begin()+1);
-
-    // Consolidate the CSR matrix data from multiple ranks sharing a single GPU to a
-    // root rank, allowing multiple ranks per GPU. This allows overdecomposing the problem
-    // when there are more CPU cores than GPUs, without the inefficiences of performing
-    // the linear solve on multiple separate domains.
-    // If the data is a device pointer then use IPC handles to perform the intra-GPU
-    // copies from the allocations of different processes operating the same GPU.
-    // Opening the handles as an initialization step means it is not necessary to
-    // repeatedly call cudaIpcOpenMemHandle, which can be expensive.
-    cudaPointerAttributes att;
-    cudaError_t err = cudaPointerGetAttributes(&att, values);
-    if (err != cudaErrorInvalidValue && att.type == cudaMemoryTypeDevice)
-    {
-        ConsolidationHandles handles;
-        // The data is already on the GPU so consolidate there
-        if (gpuProc == 0)
-        {
-            // We are consolidating data that already exists on the GPU
-            CHECK(cudaMalloc((void**)&rhsCons, sizeof(PetscScalar) * nConsRows));
-            CHECK(cudaMalloc((void**)&pCons, sizeof(PetscScalar) * nConsRows));
-            CHECK(cudaMalloc((void**)&rowOffsetsCons, sizeof(PetscInt) * (nConsRows+1)));
-            CHECK(cudaMalloc((void**)&colIndicesGlobalCons, sizeof(PetscInt) * nConsNz));
-            CHECK(cudaMalloc((void**)&valuesCons, sizeof(PetscScalar) * nConsNz));
-
-            CHECK(cudaIpcGetMemHandle(&handles.rhsConsHandle, rhsCons));
-            CHECK(cudaIpcGetMemHandle(&handles.solConsHandle, pCons));
-            CHECK(cudaIpcGetMemHandle(&handles.rowOffsetsConsHandle, rowOffsetsCons));
-            CHECK(cudaIpcGetMemHandle(&handles.colIndicesConsHandle, colIndicesGlobalCons));
-            CHECK(cudaIpcGetMemHandle(&handles.valuesConsHandle, valuesCons));
-        }
-
-        MPI_Bcast(&handles, sizeof(ConsolidationHandles), MPI_BYTE, 0, devWorld);
-
-        if(gpuProc == MPI_UNDEFINED)
-        {
-            CHECK(cudaIpcOpenMemHandle((void**)&rhsCons, handles.rhsConsHandle, cudaIpcMemLazyEnablePeerAccess));
-            CHECK(cudaIpcOpenMemHandle((void**)&pCons, handles.solConsHandle, cudaIpcMemLazyEnablePeerAccess));
-            CHECK(cudaIpcOpenMemHandle((void**)&rowOffsetsCons, handles.rowOffsetsConsHandle, cudaIpcMemLazyEnablePeerAccess));
-            CHECK(cudaIpcOpenMemHandle((void**)&colIndicesGlobalCons, handles.colIndicesConsHandle, cudaIpcMemLazyEnablePeerAccess));
-            CHECK(cudaIpcOpenMemHandle((void**)&valuesCons, handles.valuesConsHandle, cudaIpcMemLazyEnablePeerAccess));
-        }
-
-        consolidationStatus = ConsolidationStatus::Device;
-    }
-    else
-    {
-        if (gpuProc == 0)
-        {
-            // The data is already on the CPU so consolidate there
-            rowOffsetsCons = new PetscInt[nConsRows+1];
-            colIndicesGlobalCons = new PetscInt[nConsNz];
-            valuesCons = new PetscScalar[nConsNz];
-            rhsCons = new PetscScalar[nConsRows];
-            pCons = new PetscScalar[nConsRows];
-        }
-
-        consolidationStatus = ConsolidationStatus::Host;
-    }
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::consolidateMatrix */
-PetscErrorCode AmgXSolver::consolidateMatrix(
-    const PetscInt nLocalRows,
-    const PetscInt nLocalNz,
-    const PetscInt* rowOffsets,
-    const PetscInt* colIndicesGlobal,
-    const PetscScalar* values)
-{
-    PetscFunctionBeginUser;
-
-    // Consolidation has been previously used, must deallocate the structures
-    if (consolidationStatus != ConsolidationStatus::Uninitialized)
-    {
-        // XXX Would the maintainers be happy to include a warning message here,
-        // that makes it clear updateA should be preferentially adopted by developers
-        // if the sparsity pattern does not change? This would avoid many costs,
-        // including re-consolidation costs.
-
-        finalizeConsolidation();
-    }
-
-    // Allocate space for the structures required to consolidate
-    initializeConsolidation(nLocalRows, nLocalNz, values);
-
-    switch(consolidationStatus)
-    {
-
-    case ConsolidationStatus::None:
-    {
-        // Consolidation is not required
-        PetscFunctionReturn(0);
-    }
-    case ConsolidationStatus::Uninitialized:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Attempting to consolidate before consolidation is initialized.\n");
-        break;
-    }
-    case ConsolidationStatus::Device:
-    {
-        // Copy the data to the consolidation buffer
-        CHECK(cudaMemcpy(&rowOffsetsCons[rowDispls[myDevWorldRank]], rowOffsets, sizeof(PetscInt) * nLocalRows, cudaMemcpyDefault));
-        CHECK(cudaMemcpy(&colIndicesGlobalCons[nzDispls[myDevWorldRank]], colIndicesGlobal, sizeof(PetscInt) * nLocalNz, cudaMemcpyDefault));
-        CHECK(cudaMemcpy(&valuesCons[nzDispls[myDevWorldRank]], values, sizeof(PetscScalar) * nLocalNz, cudaMemcpyDefault));
-
-        // cudaMemcpy does not block the host in the cases above, device to device copies,
-        // so sychronize with device to ensure operation is complete. Barrier on all devWorld
-        // ranks to ensure full arrays are populated before the root process uses the data.
-        CHECK(cudaDeviceSynchronize());
-        int ierr = MPI_Barrier(devWorld); CHK;
-
-        if (gpuProc == 0)
-        {
-            // Adjust merged row offsets so that they are correct for the consolidated matrix
-            for (int i = 1; i < devWorldSize; ++i)
-            {
-                int nthreads = 128;
-                int nblocks = nRowsInDevWorld[i] / nthreads + 1;
-                fixConsolidatedRowOffsets<<<nblocks, nthreads>>>(nRowsInDevWorld[i], nzDispls[i], &rowOffsetsCons[rowDispls[i]]);
-            }
-
-            // Manually add the last entry of the rowOffsets list, which is the
-            // number of non-zeros in the CSR matrix
-            CHECK(cudaMemcpy(&rowOffsetsCons[nConsRows], &nConsNz, sizeof(int), cudaMemcpyDefault));
-        }
-        else
-        {
-            // Close IPC handles and deallocate for consolidation
-            CHECK(cudaIpcCloseMemHandle(rowOffsetsCons));
-            CHECK(cudaIpcCloseMemHandle(colIndicesGlobalCons));
-        }
-
-        CHECK(cudaDeviceSynchronize());
-        break;
-    }
-    case ConsolidationStatus::Host:
-    {
-        // Gather the matrix data to the root rank for consolidation
-        MPI_Request req[3];
-        int ierr = MPI_Igatherv(
-            rowOffsets, 
-            nLocalRows, 
-            MPI_INT, 
-            rowOffsetsCons, 
-            nRowsInDevWorld.data(), 
-            rowDispls.data(), 
-            MPI_INT,
-            0, 
-            devWorld, 
-            &req[0]); CHK;
-        ierr = MPI_Igatherv(
-            colIndicesGlobal, 
-            nLocalNz, 
-            MPI_INT, 
-            colIndicesGlobalCons, 
-            nnzInDevWorld.data(), 
-            nzDispls.data(), 
-            MPI_INT, 
-            0, 
-            devWorld, 
-            &req[1]); CHK;
-        ierr = MPI_Igatherv(
-            values, 
-            nLocalNz, 
-            MPI_DOUBLE, 
-            valuesCons, 
-            nnzInDevWorld.data(), 
-            nzDispls.data(), 
-            MPI_DOUBLE, 
-            0, 
-            devWorld, 
-            &req[2]); CHK;
-        MPI_Waitall(3, req, MPI_STATUSES_IGNORE);
-
-        if (gpuProc == 0)
-        {
-            // Adjust merged row offsets so that they are correct for the consolidated matrix
-            for (int j = 1; j < devWorldSize; ++j)
-            {
-                for(int i = 0; i < nRowsInDevWorld[j]; ++i)
-                {
-                    rowOffsetsCons[rowDispls[j] + i] += nzDispls[j];
-                }
-            }
-
-            // Manually add the last entry of the rowOffsets list, which is the
-            // number of non-zeros in the CSR matrix
-            rowOffsetsCons[nConsRows] = nConsNz;
-        }
-
-        break;
-    }
-    default:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Incorrect consolidation status set.\n");
-        break;
-    }
-
-    }
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::reconsolidateValues */
-PetscErrorCode AmgXSolver::reconsolidateValues(
-    const PetscInt nLocalNz,
-    const PetscScalar* values)
-{
-    PetscFunctionBeginUser;
-
-    switch (consolidationStatus)
-    {
-
-    case ConsolidationStatus::None:
-    {
-        // Consolidation is not required
-        PetscFunctionReturn(0);
-    }
-    case ConsolidationStatus::Uninitialized:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Attempting to re-consolidate before consolidation is initialized.\n");
-        break;
-    }
-    case ConsolidationStatus::Device:
-    {
-        CHECK(cudaDeviceSynchronize());
-        int ierr = MPI_Barrier(devWorld); CHK;
-
-        // The data is already on the GPU so consolidate there
-        CHECK(cudaMemcpy(&valuesCons[nzDispls[myDevWorldRank]], values, sizeof(PetscScalar) * nLocalNz, cudaMemcpyDefault));
-
-        CHECK(cudaDeviceSynchronize());
-        ierr = MPI_Barrier(devWorld); CHK;
-
-        break;
-    }
-    case ConsolidationStatus::Host:
-    {
-        // Gather the matrix values to the root rank for consolidation
-        int ierr = MPI_Gatherv(
-            values, 
-            nLocalNz, 
-            MPI_DOUBLE, 
-            valuesCons, 
-            nnzInDevWorld.data(), 
-            nzDispls.data(), 
-            MPI_DOUBLE, 
-            0, 
-            devWorld); CHK;
-        break;
-    }
-    default:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Incorrect consolidation status set.\n");
-        break;
-    }
-
-    }
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::freeConsStructure */
-PetscErrorCode AmgXSolver::freeConsStructure()
-{
-    PetscFunctionBeginUser;
-
-    // Only the root rank maintains a consolidated structure
-    if(gpuProc == MPI_UNDEFINED)
-    {
-        PetscFunctionReturn(0);
-    }
-
-    switch(consolidationStatus)
-    {
-
-    case ConsolidationStatus::None:
-    {
-        // Consolidation is not required
-        PetscFunctionReturn(0);
-    }
-    case ConsolidationStatus::Uninitialized:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                 "Attempting to free consolidation structures before consolidation is initialized.\n");
-        break;
-    }
-    case ConsolidationStatus::Device:
-    {
-        // Free the device allocated consolidated CSR matrix structure
-        CHECK(cudaFree(rowOffsetsCons));
-        CHECK(cudaFree(colIndicesGlobalCons));
-        break;
-    }
-    case ConsolidationStatus::Host:
-    {
-        // Free the host allocated consolidated CSR matrix structure
-        delete[] rowOffsetsCons;
-        delete[] colIndicesGlobalCons;
-        break;
-    }
-    default:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Incorrect consolidation status set.\n");
-        break;
-    }
-
-    }
-
-    PetscFunctionReturn(0);
-}
-
-/* \implements AmgXSolver::finalizeConsolidation */
-PetscErrorCode AmgXSolver::finalizeConsolidation()
-{
-    PetscFunctionBeginUser;
-
-    switch(consolidationStatus)
-    {
-
-    case ConsolidationStatus::None:
-    case ConsolidationStatus::Uninitialized:
-    {
-        // Consolidation is not required or uninitialized
-        PetscFunctionReturn(0);
-    }
-    case ConsolidationStatus::Device:
-    {
-        if (gpuProc == 0)
-        {
-            // Deallocate the CSR matrix values, solution and RHS
-            CHECK(cudaFree(valuesCons));
-            CHECK(cudaFree(pCons));
-            CHECK(cudaFree(rhsCons));
-        }
-        else
-        {
-            // Close the remaining IPC memory handles
-            CHECK(cudaIpcCloseMemHandle(valuesCons));
-            CHECK(cudaIpcCloseMemHandle(pCons));
-            CHECK(cudaIpcCloseMemHandle(rhsCons));
-        }
-        break;
-    }
-    case ConsolidationStatus::Host:
-    {
-        if(gpuProc == 0)
-        {
-            delete[] valuesCons;
-            delete[] pCons;
-            delete[] rhsCons;
-        }
-        break;
-    }
-    default:
-    {
-        SETERRQ(MPI_COMM_WORLD, PETSC_ERR_SUP_SYS,
-                "Incorrect consolidation status set.\n");
-        break;
-    }
-
-    }
-
-    // Free the local GPU partitioning structures
-    if(consolidationStatus == ConsolidationStatus::Device || consolidationStatus == ConsolidationStatus::Host)
-    {
-        nRowsInDevWorld.clear();
-        nnzInDevWorld.clear();
-        rowDispls.clear();
-        nzDispls.clear();
-    }
-
-    consolidationStatus = ConsolidationStatus::Uninitialized;
-
-    PetscFunctionReturn(0);
-}
-
diff --git a/src/AmgXMPIComms.cu b/src/AmgXMPIComms.cu
index 148b255..8d2afe5 100644
--- a/src/AmgXMPIComms.cu
+++ b/src/AmgXMPIComms.cu
@@ -2,8 +2,10 @@
  * \file AmgXMPIComms.cu
  * \brief ***.
  * \author Pi-Yueh Chuang (pychuang@gwu.edu)
+ * \author Matt Martineau (mmartineau@nvidia.com)
  * \date 2015-09-01
  * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
+ * \copyright Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
  *            This project is released under MIT License.
  */
 
@@ -14,10 +16,10 @@
 /* \implements AmgXSolver::initMPIcomms */
 PetscErrorCode AmgXSolver::initMPIcomms(const MPI_Comm &comm)
 {
-    PetscErrorCode      ierr;
-
     PetscFunctionBeginUser;
 
+    PetscErrorCode      ierr;
+
     // duplicate the global communicator
     ierr = MPI_Comm_dup(comm, &globalCpuWorld); CHK;
     ierr = MPI_Comm_set_name(globalCpuWorld, "globalCpuWorld"); CHK;
@@ -72,17 +74,17 @@ PetscErrorCode AmgXSolver::initMPIcomms(const MPI_Comm &comm)
 
     ierr = MPI_Barrier(globalCpuWorld); CHK;
 
-    return 0;
+    PetscFunctionReturn(0);
 }
 
 
 /* \implements AmgXSolver::setDeviceCount */
 PetscErrorCode AmgXSolver::setDeviceCount()
 {
-    PetscErrorCode      ierr;
-
     PetscFunctionBeginUser;
 
+    PetscErrorCode      ierr;
+
     // get the number of devices that AmgX solvers can use
     switch (mode)
     {
diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H
index 43ecb92..18a123d 100644
--- a/src/AmgXSolver.H
+++ b/src/AmgXSolver.H
@@ -2,9 +2,10 @@
  * \file AmgXSolver.hpp
  * \brief Definition of class AmgXSolver.
  * \author Pi-Yueh Chuang (pychuang@gwu.edu)
+ * \author Matt Martineau (mmartineau@nvidia.com)
  * \date 2015-09-01
  * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
- * \copyright Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ * \copyright Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
  *            This project is released under MIT License.
  */
 
@@ -25,6 +26,8 @@
 // PETSc
 # include <petscvec.h>
 
+#include "AmgXCSRMatrix.H"
+
 
 /** \brief A macro to check the returned CUDA error code.
  *
@@ -100,7 +103,7 @@ class AmgXSolver
         AmgXSolver
         (
             const MPI_Comm &comm,
-            const std::string &modeStr, 
+            const std::string &modeStr,
             const std::string &cfgFile
         );
 
@@ -118,10 +121,26 @@
         PetscErrorCode initialize
         (
             const MPI_Comm &comm,
-            const std::string &modeStr, 
+            const std::string &modeStr,
             const std::string &cfgFile
         );
 
+
+        /** \brief Initialise the matrix communications.
+         *
+         * Passes the \ref AmgXSolver::devWorld "devWorld" communicator and
+         * the GPU ownership flag through to the matrix, so that the ranks
+         * sharing each GPU can consolidate their local matrices onto it.
+         *
+         * \param matrix [in,out] The AmgX CSR matrix, A.
+         *
+         * \return PetscErrorCode.
+         */
+        PetscErrorCode initialiseMatrixComms
+        (
+            AmgXCSRMatrix& matrix
+        );
+
         /** \brief Finalize this instance.
          *
          * This function destroys AmgX data. When there are more than one
@@ -143,7 +162,5 @@
          * \param rowOffsets [in] The local CSR matrix row offsets.
          * \param colIndicesGlobal [in] The global CSR matrix column indices.
          * \param values [in] The local CSR matrix values.
-         * \param partData [in] Array of length nGlobalRows containing the rank
-         * id of the owning rank for each row.
          *
          * \return PetscErrorCode.
@@ -153,10 +170,7 @@
             const PetscInt nLocalRows,
             const PetscInt nGlobalRows,
             const PetscInt nLocalNz,
-            const PetscInt* rowOffsets,
-            const PetscInt* colIndicesGlobal,
-            const PetscScalar* values,
-            const PetscInt* partData
+            AmgXCSRMatrix& matrix
         );
 
         /** \brief Re-sets up an existing AmgX matrix.
@@ -174,7 +188,7 @@
         (
             const PetscInt nLocalRows,
             const PetscInt nLocalNz,
-            const PetscScalar* values
+            AmgXCSRMatrix& matrix
         );
 
         /** \brief Solve the linear system.
@@ -186,16 +200,19 @@
          * function will do data gathering before solving and data scattering
          * after the solving.
          *
-         * \param p [in, out] The unknown array.
-         * \param b [in] The RHS array.
-         * \param nRows [in] The number of rows in this rank.
+         * \param nLocalRows [in] The number of rows owned by this rank.
+         * \param pscalar [in, out] The unknown array.
+         * \param bscalar [in] The RHS array.
+         * \param matrix [in,out] The AmgX CSR matrix, A.
          *
          * \return PetscErrorCode.
          */
         PetscErrorCode solve
         (
-            Vec &p, 
-            Vec &b
+            int nLocalRows,
+            PetscScalar* pscalar,
+            PetscScalar* bscalar,
+            AmgXCSRMatrix& matrix
         );
 
         /** \brief Get the number of iterations of the last solving.
@@ -218,7 +235,7 @@
          */
         PetscErrorCode getResidual
         (
-            const int &iter, 
+            const int &iter,
             double &res
         );
 
@@ -234,12 +251,11 @@
         static int              count;
 
         /** \brief A flag indicating if this instance has been initialized. */
-        bool                    isInitialized = false;
+        bool                    isInitialised = false;
 
         /** \brief The name of the node that this MPI process belongs to. */
         std::string             nodeName;
 
-
         /** \brief Number of local GPU devices used by AmgX.*/
         PetscMPIInt             nDevs;
 
@@ -285,7 +301,6 @@
         /** \brief Rank in \ref AmgXSolver::devWorld "devWorld". */
         PetscMPIInt             myDevWorldRank;
 
-
         /** \brief A parameter used by AmgX. */
         int                     ring;
 
@@ -315,111 +330,6 @@
          */
         static AMGX_resources_handle   rsrc;
 
-
-        /** \brief Enumeration for the status of matrix consolidation for the solver.*/
-        enum class ConsolidationStatus {
-            Uninitialized,
-            None,
-            Host,
-            Device
-        } consolidationStatus;
-
-        /** \brief The number of non-zeros consolidated from multiple ranks to a device.*/
-        int nConsNz = 0;
-
-        /** \brief The number of rows consolidated from multiple ranks to a device.*/
-        int nConsRows = 0;
-
-        /** \brief The row offsets consolidated onto a single device.*/
-        PetscInt *rowOffsetsCons = nullptr;
-
-        /** \brief The global column indices consolidated onto a single device.*/
-        PetscInt *colIndicesGlobalCons = nullptr;
-
-        /** \brief The values consolidated onto a single device.*/
-        PetscScalar* valuesCons = nullptr;
-
-        /** \brief The solution vector consolidated onto a single device.*/
-        PetscScalar* pCons = nullptr;
-
-        /** \brief The RHS vector consolidated onto a single device.*/
-        PetscScalar* rhsCons = nullptr;
-
-        /** \brief The number of rows per rank associated with a single device.*/
-        std::vector<int> nRowsInDevWorld;
-
-        /** \brief The number of non-zeros per rank associated with a single device.*/
-        std::vector<int> nnzInDevWorld;
-
-        /** \brief The row displacements per rank associated with a single device.*/
-        std::vector<int> rowDispls;
-
-        /** \brief The non-zero displacements per rank associated with a single device.*/
-        std::vector<int> nzDispls;
-
-        /** \brief Initialize consolidation, if required. Allocates space to hold
-         * consolidated CSR matrix and solution/RHS vectors on the root rank and,
-         * if device pointer consolidation, allocates and opens IPC memory handles.
-         * Calculates and stores consolidated sizes and displacements.
-         *
-         * \param nLocalRows [in] The number of rows owned by this rank.
-         * \param nLocalNz [in] The number of non-zeros owned by this rank.
-         * \param values [in] The values of the CSR matrix A.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode initializeConsolidation(
-            const PetscInt nLocalRows,
-            const PetscInt nLocalNz,
-            const PetscScalar *values);
-
-        /** \brief Realise consolidation, if required. This copies data from multiple ranks
-         * that share a single GPU, so a single consolidated CSR matrix can be passed to
-         * AmgX. As such, users can efficiently execute with more MPI ranks than GPUs,
-         * potentially improving performance where un-accelerated components of an application
-         * require CPU resources. CUDA IPC is leveraged to avoid buffering data onto the host,
-         * rather enabling fast intra-GPU copies if the local CSR matrices are already on the GPU.
-         *
-         * \param nLocalRows [in] The number of rows owned by this rank.
-         * \param nLocalNz [in] The number of non-zeros owned by this rank.
-         * \param rowOffsets [in] The row offsets of the CSR matrix A.
-         * \param colIndicesGlobal [in] The global column indices of the CSR matrix A.
-         * \param values [in] The values of the CSR matrix A.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode consolidateMatrix(
-            const PetscInt nLocalRows,
-            const PetscInt nLocalNz,
-            const PetscInt *rowOffsets,
-            const PetscInt *colIndicesGlobal,
-            const PetscScalar *values);
-
-        /** \brief Re-consolidates the values of the CSR matrix from multiple ranks.
-         *
-         * \param nLocalNz [in] The number of non-zeros owned by this rank.
-         * \param values [in] The values of the CSR matrix A.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode reconsolidateValues(
-            const PetscInt nLocalNz,
-            const PetscScalar *values);
-
-        /** \brief De-allocates consolidated data structures, if any, after the AmgX matrix
-         * has been constructed and the duplicate data becomes redundant.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode freeConsStructure();
-
-        /** \brief De-allocates all consolidated matrix structures.
-         *
-         * \return PetscErrorCode.
-         */
-        PetscErrorCode finalizeConsolidation();
-
-
         /** \brief Set AmgX solver mode based on the user-provided string.
          *
          * Available modes are: dDDI, dDFI, dFFI, hDDI, hDFI, hFFI.
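
Taken together, the revised interface implies the calling sequence below. This is a hypothetical driver, so the array and size names (and the config file name) are placeholders for whatever the host application provides:

    #include <mpi.h>
    #include "AmgXSolver.H"

    void solveSystem(int nLocalRows, int nGlobalRows, int nLocalNz,
                     PetscScalar* p, PetscScalar* b, AmgXCSRMatrix& matrix)
    {
        AmgXSolver solver(MPI_COMM_WORLD, "dDDI", "amgx.json");

        // Hand the device-sharing communicator to the matrix before any
        // consolidation work takes place.
        solver.initialiseMatrixComms(matrix);

        // ... populate `matrix` with the local CSR system here ...

        solver.setOperator(nLocalRows, nGlobalRows, nLocalNz, matrix);
        solver.solve(nLocalRows, p, b, matrix);

        // On later iterations only the coefficients change, so the cheaper
        // re-setup path suffices.
        solver.updateOperator(nLocalRows, nLocalNz, matrix);
        solver.solve(nLocalRows, p, b, matrix);

        solver.finalize();
    }
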
diff --git a/src/AmgXSolver.cu b/src/AmgXSolver.cu
index 99717df..5c86a32 100644
--- a/src/AmgXSolver.cu
+++ b/src/AmgXSolver.cu
@@ -2,13 +2,17 @@
  * \file AmgXSolver.cpp
  * \brief Definition of member functions of the class AmgXSolver.
  * \author Pi-Yueh Chuang (pychuang@gwu.edu)
+ * \author Matt Martineau (mmartineau@nvidia.com)
  * \date 2015-09-01
  * \copyright Copyright (c) 2015-2019 Pi-Yueh Chuang, Lorena A. Barba.
+ * \copyright Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
  *            This project is released under MIT License.
  */
 
 // AmgXWrapper
-# include "AmgXSolver.H"
+#include "AmgXSolver.H"
+#include <numeric>
+#include <limits>
 
 // initialize AmgXSolver::count to 0
 int AmgXSolver::count = 0;
@@ -28,7 +32,7 @@ AmgXSolver::AmgXSolver(const MPI_Comm &comm,
 /* \implements AmgXSolver::~AmgXSolver */
 AmgXSolver::~AmgXSolver()
 {
-    if (isInitialized) finalize();
+    if (isInitialised) finalize();
 }
 
 
@@ -36,12 +40,12 @@ AmgXSolver::~AmgXSolver()
 PetscErrorCode AmgXSolver::initialize(const MPI_Comm &comm,
         const std::string &modeStr, const std::string &cfgFile)
 {
-    PetscErrorCode      ierr;
-
     PetscFunctionBeginUser;
 
+    PetscErrorCode      ierr;
+
     // if this instance has already been initialized, skip
-    if (isInitialized) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE,
+    if (isInitialised) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE,
             "This AmgXSolver instance has been initialized on this process.");
 
     // increase the number of AmgXSolver instances
@@ -66,13 +70,20 @@ PetscErrorCode AmgXSolver::initialize(const MPI_Comm &comm,
     }
 
     // a bool indicating if this instance is initialized
-    isInitialized = true;
-
-    consolidationStatus = ConsolidationStatus::Uninitialized;
+    isInitialised = true;
 
     PetscFunctionReturn(0);
 }
 
+PetscErrorCode AmgXSolver::initialiseMatrixComms(
+    AmgXCSRMatrix& matrix)
+{
+    PetscFunctionBeginUser;
+
+    matrix.initialiseComms(devWorld, gpuProc);
+
+    PetscFunctionReturn(0);
+}
 
 /* \implements AmgXSolver::setMode */
 PetscErrorCode AmgXSolver::setMode(const std::string &modeStr)
@@ -146,22 +157,20 @@ PetscErrorCode AmgXSolver::initAmgX(const std::string &cfgFile)
     PetscFunctionReturn(0);
 }
 
-
 /* \implements AmgXSolver::finalize */
 PetscErrorCode AmgXSolver::finalize()
 {
-    PetscErrorCode      ierr;
-
     PetscFunctionBeginUser;
 
-    // skip if this instance has not been initialized
-    if (! isInitialized)
-    {
-        ierr = PetscPrintf(PETSC_COMM_WORLD,
-                "This AmgXWrapper has not been initialized. "
-                "Please initialize it before finalization.\n"); CHK;
+    PetscErrorCode      ierr;
 
-        PetscFunctionReturn(0);
+    // skip if this instance has not been initialised
+    if (!isInitialised)
+    {
+        fprintf(stderr,
+                "This AmgXWrapper has not been initialised. "
+                "Please initialise it before finalisation.\n");
+        PetscFunctionReturn(1);
     }
 
     // only processes using GPU are required to destroy AmgX content
@@ -195,8 +204,6 @@ PetscErrorCode AmgXSolver::finalize()
         ierr = MPI_Comm_free(&gpuWorld); CHK;
     }
 
-    finalizeConsolidation();
-
     // re-set necessary variables in case users want to reuse
     // the variable of this instance for a new instance
     gpuProc = MPI_UNDEFINED;
@@ -208,28 +215,43 @@ PetscErrorCode AmgXSolver::finalize()
     count -= 1;
 
     // change status
-    isInitialized = false;
+    isInitialised = false;
 
     PetscFunctionReturn(0);
 }
 
-
 /* \implements AmgXSolver::setOperator */
 PetscErrorCode AmgXSolver::setOperator
 (
     const PetscInt nLocalRows,
     const PetscInt nGlobalRows,
     const PetscInt nLocalNz,
-    const PetscInt* rowOffsets,
-    const PetscInt* colIndicesGlobal,
-    const PetscScalar* values,
-    const PetscInt* partData
+    AmgXCSRMatrix& matrix
 )
 {
     PetscFunctionBeginUser;
 
-    // Merge the distributed matrix for MPI processes sharing a GPU
-    consolidateMatrix(nLocalRows, nLocalNz, rowOffsets, colIndicesGlobal, values);
+    // Check that the global matrix size does not exceed what AmgX supports
+    if (nGlobalRows > std::numeric_limits<int>::max())
+    {
+        fprintf(stderr,
+                "AmgX does not support a global number of rows greater than "
+                "what can be stored in 32 bits (nGlobalRows = %lld).\n",
+                (long long)nGlobalRows);
+        PetscFunctionReturn(1);
+    }
+
+    const int nRows = (matrix.isConsolidated()) ? matrix.getNConsRows() : nLocalRows;
+    const PetscInt nNz = (matrix.isConsolidated()) ? matrix.getNConsNz() : nLocalNz;
+
+    if (nNz > std::numeric_limits<int>::max())
+    {
+        fprintf(stderr,
+                "AmgX does not support a non-zero count per (consolidated) rank "
+                "greater than what can be stored in 32 bits (nNz = %lld).\n",
+                (long long)nNz);
+        PetscFunctionReturn(1);
+    }
 
     int ierr;
 
@@ -238,32 +260,39 @@ PetscErrorCode AmgXSolver::setOperator
     {
         ierr = MPI_Barrier(gpuWorld); CHK;
 
-        if (consolidationStatus == ConsolidationStatus::None)
-        {
-            AMGX_matrix_upload_all_global_32(
-                AmgXA, nGlobalRows, nLocalRows, nLocalNz,
-                1, 1, rowOffsets, colIndicesGlobal, values,
-                nullptr, ring, ring, partData);
-        }
-        else
-        {
-            AMGX_matrix_upload_all_global_32(
-                AmgXA, nGlobalRows, nConsRows, nConsNz,
-                1, 1, rowOffsetsCons, colIndicesGlobalCons, valuesCons,
-                nullptr, ring, ring, partData);
+        AMGX_distribution_handle dist;
+        AMGX_distribution_create(&dist, cfg);
 
-            // The rowOffsets and colIndices are no longer needed
-            freeConsStructure();
-        }
+        // Must persist until after we call upload
+        std::vector<int> offsets(gpuWorldSize + 1, 0);
+
+        // Determine the number of rows per GPU
+        std::vector<int> nRowsPerGPU(gpuWorldSize);
+        ierr = MPI_Allgather(&nRows, 1, MPI_INT, nRowsPerGPU.data(), 1, MPI_INT, gpuWorld); CHK;
+
+        // Calculate the global offsets
+        std::partial_sum(nRowsPerGPU.begin(), nRowsPerGPU.end(), offsets.begin() + 1);
+
+        AMGX_distribution_set_partition_data(
+            dist, AMGX_DIST_PARTITION_OFFSETS, offsets.data());
+
+        // Set the column indices size, 32- / 64-bit
+        AMGX_distribution_set_32bit_colindices(dist, true);
+
+        AMGX_matrix_upload_distributed(
+            AmgXA, nGlobalRows, nRows, nNz, 1, 1, matrix.getRowOffsets(),
+            matrix.getColIndices(), matrix.getValues(), nullptr, dist);
+
+        AMGX_distribution_destroy(dist);
 
         // bind the matrix A to the solver
-        ierr = MPI_Barrier(gpuWorld); CHK;
         AMGX_solver_setup(solver, AmgXA);
 
         // connect (bind) vectors to the matrix
         AMGX_vector_bind(AmgXP, AmgXA);
         AMGX_vector_bind(AmgXRHS, AmgXA);
     }
+
     ierr = MPI_Barrier(globalCpuWorld); CHK;
 
     PetscFunctionReturn(0);
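
The offsets array built in this hunk is the exclusive prefix sum of the per-GPU row counts: with three GPUs owning 100, 80 and 120 rows, nRowsPerGPU = {100, 80, 120} yields offsets = {0, 100, 180, 300}, and the final entry must equal nGlobalRows. A self-contained sketch of that computation:

    #include <cassert>
    #include <numeric>
    #include <vector>

    // Exclusive prefix sum of per-GPU row counts -> global row offsets,
    // mirroring the std::partial_sum call above.
    std::vector<int> buildOffsets(const std::vector<int>& nRowsPerGPU,
                                  int nGlobalRows)
    {
        std::vector<int> offsets(nRowsPerGPU.size() + 1, 0);
        std::partial_sum(nRowsPerGPU.begin(), nRowsPerGPU.end(),
                         offsets.begin() + 1);
        assert(offsets.back() == nGlobalRows);  // every row must be owned
        return offsets;
    }
    // e.g. buildOffsets({100, 80, 120}, 300) == {0, 100, 180, 300}
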
@@ -275,30 +304,19 @@ PetscErrorCode AmgXSolver::updateOperator
 (
     const PetscInt nLocalRows,
     const PetscInt nLocalNz,
-    const PetscScalar* values
+    AmgXCSRMatrix& matrix
 )
 {
     PetscFunctionBeginUser;
 
-    // Merges the values from multiple MPI processes sharing a single GPU
-    reconsolidateValues(nLocalNz, values);
+    const int nRows = (matrix.isConsolidated()) ? matrix.getNConsRows() : nLocalRows;
+    const int nNz = (matrix.isConsolidated()) ? matrix.getNConsNz() : nLocalNz;
 
     int ierr;
     // Replace the coefficients for the CSR matrix A within AmgX
     if (gpuWorld != MPI_COMM_NULL)
     {
-        ierr = MPI_Barrier(gpuWorld); CHK;
-
-        if (consolidationStatus == ConsolidationStatus::None)
-        {
-            AMGX_matrix_replace_coefficients(AmgXA, nLocalRows, nLocalNz, values, nullptr);
-        }
-        else
-        {
-            AMGX_matrix_replace_coefficients(AmgXA, nConsRows, nConsNz, valuesCons, nullptr);
-        }
-
-        ierr = MPI_Barrier(gpuWorld); CHK;
+        AMGX_matrix_replace_coefficients(AmgXA, nRows, nNz, matrix.getValues(), nullptr);
 
         // Re-setup the solver (a reduced overhead setup that accounts for consistent matrix structure)
         AMGX_solver_resetup(solver, AmgXA);
@@ -311,88 +329,45 @@ PetscErrorCode AmgXSolver::updateOperator
 
 
 /* \implements AmgXSolver::solve */
-PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b)
+PetscErrorCode AmgXSolver::solve(
+    int nLocalRows, PetscScalar* pscalar, PetscScalar* bscalar, AmgXCSRMatrix& matrix)
 {
     PetscFunctionBeginUser;
 
-    int ierr;
+    PetscScalar* p;
+    PetscScalar* b;
+    PetscInt nRows;
 
-    PetscScalar* pscalar;
-    PetscScalar* bscalar;
+    PetscErrorCode ierr;
+    if (matrix.isConsolidated())
+    {
+        p = matrix.getPCons();
+        b = matrix.getRHSCons();
 
-    // get pointers to the raw data of local vectors
-    ierr = VecGetArray(p, &pscalar); CHK;
-    ierr = VecGetArray(b, &bscalar); CHK;
+        const int* rowDispls = matrix.getRowDispls();
+        CHECK(cudaMemcpy((void*)&p[rowDispls[myDevWorldRank]], pscalar, sizeof(PetscScalar) * nLocalRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void*)&b[rowDispls[myDevWorldRank]], bscalar, sizeof(PetscScalar) * nLocalRows, cudaMemcpyDefault));
 
-    PetscInt nRows;
-    ierr = VecGetLocalSize(p, &nRows); CHK;
+        // Override the number of rows with the consolidated row count
+        nRows = matrix.getNConsRows();
 
-    if (consolidationStatus == ConsolidationStatus::Device)
-    {
-        CHECK(
-            cudaMemcpy(
-                (void**)&pCons[rowDispls[myDevWorldRank]],
-                pscalar, 
-                sizeof(PetscScalar) * nRows, 
-                cudaMemcpyDefault
-            )
-        );
-        CHECK(
-            cudaMemcpy(
-                (void**)&rhsCons[rowDispls[myDevWorldRank]], 
-                bscalar, 
-                sizeof(PetscScalar) * nRows, 
-                cudaMemcpyDefault
-            )
-        );
-
-        // Must synchronize here as device to device copies are non-blocking w.r.t host
+        // Sync: cudaMemcpy to the IPC buffers is device-to-device, hence
+        // non-blocking w.r.t. the host; all ranks in devWorld see the same isConsolidated
         CHECK(cudaDeviceSynchronize());
         ierr = MPI_Barrier(devWorld); CHK;
     }
-    else if (consolidationStatus == ConsolidationStatus::Host)
+    else
     {
-        MPI_Request req[2];
-        ierr = MPI_Igatherv(
-            pscalar, 
-            nRows, 
-            MPI_DOUBLE, 
-            &pCons[rowDispls[myDevWorldRank]], 
-            nRowsInDevWorld.data(), 
-            rowDispls.data(), 
-            MPI_DOUBLE, 
-            0, 
-            devWorld, 
-            &req[0]
-        ); CHK;
-        ierr = MPI_Igatherv(
-            bscalar, 
-            nRows, 
-            MPI_DOUBLE, 
-            &rhsCons[rowDispls[myDevWorldRank]], 
-            nRowsInDevWorld.data(), 
-            rowDispls.data(), 
-            MPI_DOUBLE, 
-            0, 
-            devWorld, 
-            &req[1]
-        ); CHK;
-        MPI_Waitall(2, req, MPI_STATUSES_IGNORE);
+        p = pscalar;
+        b = bscalar;
+        nRows = nLocalRows;
     }
 
     if (gpuWorld != MPI_COMM_NULL)
     {
         // Upload potentially consolidated vectors to AmgX
-        if (consolidationStatus == ConsolidationStatus::None)
-        {
-            AMGX_vector_upload(AmgXP, nRows, 1, pscalar);
-            AMGX_vector_upload(AmgXRHS, nRows, 1, bscalar);
-        }
-        else
-        {
-            AMGX_vector_upload(AmgXP, nConsRows, 1, pCons);
-            AMGX_vector_upload(AmgXRHS, nConsRows, 1, rhsCons);
-        }
+        AMGX_vector_upload(AmgXP, nRows, 1, p);
+        AMGX_vector_upload(AmgXRHS, nRows, 1, b);
 
         ierr = MPI_Barrier(gpuWorld); CHK;
 
@@ -400,63 +375,42 @@ PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b)
         AMGX_solver_solve(solver, AmgXRHS, AmgXP);
 
         // Get the status of the solver
-        AMGX_SOLVE_STATUS   status;
+        AMGX_SOLVE_STATUS status;
         AMGX_solver_get_status(solver, &status);
 
         // Check whether the solver successfully solved the problem
         if (status != AMGX_SOLVE_SUCCESS)
-                SETERRQ1(globalCpuWorld,
-                        PETSC_ERR_CONV_FAILED, "AmgX solver failed to solve the system! "
-                                                "The error code is %d.\n",
-                        status);
-
-        // Download data from device
-        if (consolidationStatus == ConsolidationStatus::None)
         {
-            AMGX_vector_download(AmgXP, pscalar);
+            fprintf(stderr, "AmgX solver failed to solve the system! "
+                            "The error code is %d.\n",
+                    status);
         }
-        else
-        {
-            AMGX_vector_download(AmgXP, pCons);
 
-            // AMGX_vector_download invokes a device to device copy here, so it is essential that
+        // Download data from device
+        AMGX_vector_download(AmgXP, p);
+
+        if (matrix.isConsolidated())
+        {
+            // AMGX_vector_download invokes a device to device copy, so it is essential that
             // the root rank blocks the host before other ranks copy from the consolidated solution
             CHECK(cudaDeviceSynchronize());
         }
     }
 
-    // If the matrix is consolidated, scatter the
-    if (consolidationStatus == ConsolidationStatus::Device)
+    // If the matrix is consolidated, scatter the solution
+    if (matrix.isConsolidated())
     {
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
 
-        CHECK(
-            cudaMemcpy(
-                (void **)pscalar, 
-                &pCons[rowDispls[myDevWorldRank]], 
-                sizeof(PetscScalar) * nRows, 
-                cudaMemcpyDefault
-            )
-        );
-        CHECK(cudaDeviceSynchronize());
-    }
-    else if (consolidationStatus == ConsolidationStatus::Host)
-    {
-        // Must synchronise before each rank attempts to read from the consolidated solution
-        ierr = MPI_Barrier(devWorld); CHK;
+        const int* rowDispls = matrix.getRowDispls();
 
-        ierr = MPI_Scatterv(
-            &pCons[rowDispls[myDevWorldRank]], 
-            nRowsInDevWorld.data(), 
-            rowDispls.data(), 
-            MPI_DOUBLE, 
-            pscalar, 
-            nRows, 
-            MPI_DOUBLE, 
-            0, 
-            devWorld
-        ); CHK;
+        // Ranks copy the portion of the solution they own into their rank-local buffers
+        CHECK(cudaMemcpy((void*)pscalar, &p[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nLocalRows, cudaMemcpyDefault));
+
+        // Sync: cudaMemcpy from the IPC buffers is device-to-device, hence
+        // non-blocking w.r.t. the host; all ranks in devWorld see the same isConsolidated
+        CHECK(cudaDeviceSynchronize());
     }
 
     ierr = MPI_Barrier(globalCpuWorld); CHK;
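
Stripped of the solver plumbing, the consolidated solve above is a gather/solve/scatter around the IPC-shared vectors. The sketch below condenses that per-rank flow, with matrix, devWorld, myDevWorldRank, pscalar and nLocalRows as in the code above and the AmgX calls elided:

    // Gather: copy this rank's unknowns into its slice of the shared vector.
    double* pCons = matrix.getPCons();
    const int* displs = matrix.getRowDispls();
    cudaMemcpy(&pCons[displs[myDevWorldRank]], pscalar,
               sizeof(PetscScalar) * nLocalRows, cudaMemcpyDefault);
    cudaDeviceSynchronize();   // device-to-device copies are async w.r.t. the host
    MPI_Barrier(devWorld);     // all slices present before the owner uploads

    // ... the GPU-owning rank uploads pCons, solves and downloads it ...

    MPI_Barrier(devWorld);     // solution fully downloaded before anyone reads
    cudaMemcpy(pscalar, &pCons[displs[myDevWorldRank]],
               sizeof(PetscScalar) * nLocalRows, cudaMemcpyDefault);
    cudaDeviceSynchronize();
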
-- 
GitLab