diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H
index 7ac8a56d0a91ac1ebccc75707e18fbffc182536d..29d3660c820d77f69dfc9eed0135f3fe1577cb57 100644
--- a/src/AmgXSolver.H
+++ b/src/AmgXSolver.H
@@ -221,7 +221,7 @@ class AmgXSolver
          *
          * \return PetscErrorCode.
          */
-        PetscErrorCode solve(PetscScalar *p, const PetscScalar *b, const int nRows);
+        PetscErrorCode solve(Vec &p, Vec &b, const int nRows);
 
 
         /** \brief Get the number of iterations of the last solving.
diff --git a/src/solve.cu b/src/solve.cu
index 55a33f885f673983de3f480cb9a1aabdf33c27f3..f74e1a4ac0a13859a60f353d5f56f6149c1080a0 100644
--- a/src/solve.cu
+++ b/src/solve.cu
@@ -104,16 +104,24 @@ PetscErrorCode AmgXSolver::solve_real(Vec &p, Vec &b)
 
 
 /* \implements AmgXSolver::solve */
-PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int nRows)
+PetscErrorCode AmgXSolver::solve(Vec &p, Vec &b, const int nRows)
 {
     PetscFunctionBeginUser;
 
     int ierr;
 
+    PetscScalar* pscalar;
+    const PetscScalar* bscalar;
+
+    // Get raw pointers to the local vector data; b is read-only, so use the read variant
+    ierr = VecGetArray(p, &pscalar); CHK;
+    ierr = VecGetArrayRead(b, &bscalar); CHK;
+
     if (consolidationStatus == ConsolidationStatus::Device)
     {
-        CHECK(cudaMemcpy((void**)&pCons[rowDispls[myDevWorldRank]], p, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
-        CHECK(cudaMemcpy((void**)&rhsCons[rowDispls[myDevWorldRank]], b, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
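+        // Stage this rank's local solution and RHS into its slice of the consolidated device buffers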
+        CHECK(cudaMemcpy(&pCons[rowDispls[myDevWorldRank]], pscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy(&rhsCons[rowDispls[myDevWorldRank]], bscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
 
         // Must synchronize here as device to device copies are non-blocking w.r.t host
         CHECK(cudaDeviceSynchronize());
@@ -122,8 +130,9 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     else if (consolidationStatus == ConsolidationStatus::Host)
     {
         MPI_Request req[2];
-        ierr = MPI_Igatherv(p, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
-        ierr = MPI_Igatherv(b, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
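+        // Gather all local rows onto the root rank; PetscScalar is assumed to be a real double, matching MPI_DOUBLE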
+        ierr = MPI_Igatherv(pscalar, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
+        ierr = MPI_Igatherv(bscalar, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
         MPI_Waitall(2, req, MPI_STATUSES_IGNORE);
     }
 
@@ -132,8 +141,8 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Upload potentially consolidated vectors to AmgX
         if (consolidationStatus == ConsolidationStatus::None)
         {
-            AMGX_vector_upload(AmgXP, nRows, 1, p);
-            AMGX_vector_upload(AmgXRHS, nRows, 1, b);
+            AMGX_vector_upload(AmgXP, nRows, 1, pscalar);
+            AMGX_vector_upload(AmgXRHS, nRows, 1, bscalar);
         }
         else
         {
@@ -160,7 +169,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Download data from device
         if (consolidationStatus == ConsolidationStatus::None)
         {
-            AMGX_vector_download(AmgXP, p);
+            AMGX_vector_download(AmgXP, pscalar);
         }
         else
         {
@@ -178,7 +187,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
 
-        CHECK(cudaMemcpy((void **)p, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy(pscalar, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
         CHECK(cudaDeviceSynchronize());
     }
     else if (consolidationStatus == ConsolidationStatus::Host)
@@ -186,7 +195,12 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
 
-        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, p, nRows, MPI_DOUBLE, 0, devWorld); CHK;
+        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, pscalar, nRows, MPI_DOUBLE, 0, devWorld); CHK;
     }
 
     ierr = MPI_Barrier(globalCpuWorld); CHK;
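+
+    // Hand the raw arrays back to PETSc; every VecGetArray/VecGetArrayRead
+    // must be paired with the corresponding restore before the Vecs are reused
+    ierr = VecRestoreArray(p, &pscalar); CHK;
+    ierr = VecRestoreArrayRead(b, &bscalar); CHK;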