From e5d5379f7ee199f51fb2893749ff23a86bb22a2a Mon Sep 17 00:00:00 2001
From: Simone Bna <s.bn@cineca.it>
Date: Wed, 18 Nov 2020 18:26:53 +0100
Subject: [PATCH] ENH: used Vec type instead of PetscScalar in the arguments
 of the solve function

---
 src/AmgXSolver.H |  2 +-
 src/solve.cu     | 27 +++++++++++++++++----------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/AmgXSolver.H b/src/AmgXSolver.H
index 7ac8a56..29d3660 100644
--- a/src/AmgXSolver.H
+++ b/src/AmgXSolver.H
@@ -221,7 +221,7 @@ class AmgXSolver
          *
          * \return PetscErrorCode.
          */
-        PetscErrorCode solve(PetscScalar *p, const PetscScalar *b, const int nRows);
+        PetscErrorCode solve(Vec &p, Vec &b, const int nRows);
 
 
         /** \brief Get the number of iterations of the last solving.
diff --git a/src/solve.cu b/src/solve.cu
index 55a33f8..f74e1a4 100644
--- a/src/solve.cu
+++ b/src/solve.cu
@@ -104,16 +104,23 @@ PetscErrorCode AmgXSolver::solve_real(Vec &p, Vec &b)
 
 
 /* \implements AmgXSolver::solve */
-PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int nRows)
+PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b, const int nRows)
 {
     PetscFunctionBeginUser;
 
     int ierr;
 
+    PetscScalar* pscalar;
+    PetscScalar* bscalar;
+
+    // get pointers to the raw data of local vectors
+    ierr = VecGetArray(p, &pscalar); CHK;
+    ierr = VecGetArray(b, &bscalar); CHK;
+
     if (consolidationStatus == ConsolidationStatus::Device)
     {
-        CHECK(cudaMemcpy((void**)&pCons[rowDispls[myDevWorldRank]], p, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
-        CHECK(cudaMemcpy((void**)&rhsCons[rowDispls[myDevWorldRank]], b, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void**)&pCons[rowDispls[myDevWorldRank]], pscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void**)&rhsCons[rowDispls[myDevWorldRank]], bscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
 
         // Must synchronize here as device to device copies are non-blocking w.r.t host
         CHECK(cudaDeviceSynchronize());
@@ -122,8 +129,8 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     else if (consolidationStatus == ConsolidationStatus::Host)
     {
         MPI_Request req[2];
-        ierr = MPI_Igatherv(p, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
-        ierr = MPI_Igatherv(b, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
+        ierr = MPI_Igatherv(pscalar, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
+        ierr = MPI_Igatherv(bscalar, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
 
         MPI_Waitall(2, req, MPI_STATUSES_IGNORE);
     }
@@ -132,8 +139,8 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     // Upload potentially consolidated vectors to AmgX
     if (consolidationStatus == ConsolidationStatus::None)
     {
-        AMGX_vector_upload(AmgXP, nRows, 1, p);
-        AMGX_vector_upload(AmgXRHS, nRows, 1, b);
+        AMGX_vector_upload(AmgXP, nRows, 1, pscalar);
+        AMGX_vector_upload(AmgXRHS, nRows, 1, bscalar);
     }
     else
     {
@@ -160,7 +167,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     // Download data from device
     if (consolidationStatus == ConsolidationStatus::None)
     {
-        AMGX_vector_download(AmgXP, p);
+        AMGX_vector_download(AmgXP, pscalar);
    }
    else
    {
@@ -178,7 +185,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
 
-        CHECK(cudaMemcpy((void **)p, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void **)pscalar, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
         CHECK(cudaDeviceSynchronize());
     }
     else if (consolidationStatus == ConsolidationStatus::Host)
@@ -186,7 +193,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
 
-        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, p, nRows, MPI_DOUBLE, 0, devWorld); CHK;
+        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, pscalar, nRows, MPI_DOUBLE, 0, devWorld); CHK;
     }
 
     ierr = MPI_Barrier(globalCpuWorld); CHK;
-- 
GitLab
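
Note: a minimal caller sketch for the new Vec-based signature follows. The helper below is hypothetical and not part of this patch; it assumes the solver was already configured through the existing AmgXSolver setup path (initialization and matrix upload), and that nRows is the rank-local row count, since VecGetArray() inside solve() yields the local array. Be aware that PETSc expects every VecGetArray() to be paired with a VecRestoreArray() before the Vec is used elsewhere; the hunks above acquire pscalar/bscalar but never restore them, so a follow-up change along those lines may be needed at the end of solve().

// Hypothetical usage sketch (assumes a previously initialized solver and
// an uploaded matrix; names follow the existing AmgXSolver/PETSc APIs).
#include <petscvec.h>
#include "AmgXSolver.H"

PetscErrorCode solveWithVecs(AmgXSolver &solver, Vec &p, Vec &b)
{
    PetscErrorCode ierr;
    PetscInt nLocalRows;

    // solve() consumes the rank-local array obtained via VecGetArray(),
    // so the size passed must be the local one, not the global one.
    ierr = VecGetLocalSize(p, &nLocalRows); CHKERRQ(ierr);

    ierr = solver.solve(p, b, (int)nLocalRows); CHKERRQ(ierr);
    return 0;
}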