Commit e5d5379f authored by Simone Bna

ENH: used Vec type instead of PetscScalar in the arguments of the solve function

parent 2562d871
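In short: solve() now takes the PETSc Vec handles directly and extracts the raw PetscScalar arrays internally with VecGetArray, instead of requiring callers to pass bare pointers. A minimal caller-side sketch of the new signature follows (not part of this commit; the runSolve wrapper, the header name and the include line are assumptions for illustration):

// Hypothetical caller-side sketch: with the new signature the PETSc Vec
// handles are passed directly, and the raw-array extraction now happens
// inside AmgXSolver::solve.
#include <petscvec.h>
#include "AmgXSolver.H"   // header name assumed for illustration

PetscErrorCode runSolve(AmgXSolver& solver, Vec p, Vec b)
{
    PetscErrorCode ierr;
    PetscInt nLocalRows;

    // number of locally owned rows of the solution vector
    ierr = VecGetLocalSize(p, &nLocalRows); CHKERRQ(ierr);

    // old API: the caller had to extract PetscScalar* pointers itself;
    // new API: pass the Vec objects plus the local row count
    ierr = solver.solve(p, b, static_cast<int>(nLocalRows)); CHKERRQ(ierr);

    return 0;
}

VecGetLocalSize is used here only to obtain the local row count that solve() still expects as its third argument.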
@@ -221,7 +221,7 @@ class AmgXSolver
      *
      * \return PetscErrorCode.
      */
-    PetscErrorCode solve(PetscScalar *p, const PetscScalar *b, const int nRows);
+    PetscErrorCode solve(Vec &p, Vec &b, const int nRows);
     /** \brief Get the number of iterations of the last solving.
@@ -104,16 +104,23 @@ PetscErrorCode AmgXSolver::solve_real(Vec &p, Vec &b)
 /* \implements AmgXSolver::solve */
-PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int nRows)
+PetscErrorCode AmgXSolver::solve(Vec& p, Vec &b, const int nRows)
 {
     PetscFunctionBeginUser;
     int ierr;
+    PetscScalar* pscalar;
+    PetscScalar* bscalar;
+
+    // get pointers to the raw data of local vectors
+    ierr = VecGetArray(p, &pscalar); CHK;
+    ierr = VecGetArray(b, &bscalar); CHK;
+
     if (consolidationStatus == ConsolidationStatus::Device)
     {
-        CHECK(cudaMemcpy((void**)&pCons[rowDispls[myDevWorldRank]], p, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
-        CHECK(cudaMemcpy((void**)&rhsCons[rowDispls[myDevWorldRank]], b, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void**)&pCons[rowDispls[myDevWorldRank]], pscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void**)&rhsCons[rowDispls[myDevWorldRank]], bscalar, sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
         // Must synchronize here as device to device copies are non-blocking w.r.t host
         CHECK(cudaDeviceSynchronize());
@@ -122,8 +129,8 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     else if (consolidationStatus == ConsolidationStatus::Host)
     {
         MPI_Request req[2];
-        ierr = MPI_Igatherv(p, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
-        ierr = MPI_Igatherv(b, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
+        ierr = MPI_Igatherv(pscalar, nRows, MPI_DOUBLE, &pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[0]); CHK;
+        ierr = MPI_Igatherv(bscalar, nRows, MPI_DOUBLE, &rhsCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, 0, devWorld, &req[1]); CHK;
         MPI_Waitall(2, req, MPI_STATUSES_IGNORE);
     }
@@ -132,8 +139,8 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     // Upload potentially consolidated vectors to AmgX
     if (consolidationStatus == ConsolidationStatus::None)
     {
-        AMGX_vector_upload(AmgXP, nRows, 1, p);
-        AMGX_vector_upload(AmgXRHS, nRows, 1, b);
+        AMGX_vector_upload(AmgXP, nRows, 1, pscalar);
+        AMGX_vector_upload(AmgXRHS, nRows, 1, bscalar);
     }
     else
     {
@@ -160,7 +167,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
     // Download data from device
     if (consolidationStatus == ConsolidationStatus::None)
     {
-        AMGX_vector_download(AmgXP, p);
+        AMGX_vector_download(AmgXP, pscalar);
     }
     else
     {
@@ -178,7 +185,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
-        CHECK(cudaMemcpy((void **)p, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
+        CHECK(cudaMemcpy((void **)pscalar, &pCons[rowDispls[myDevWorldRank]], sizeof(PetscScalar) * nRows, cudaMemcpyDefault));
         CHECK(cudaDeviceSynchronize());
     }
     else if (consolidationStatus == ConsolidationStatus::Host)
@@ -186,7 +193,7 @@ PetscErrorCode AmgXSolver::solve(PetscScalar* p, const PetscScalar* b, const int
         // Must synchronise before each rank attempts to read from the consolidated solution
         ierr = MPI_Barrier(devWorld); CHK;
-        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, p, nRows, MPI_DOUBLE, 0, devWorld); CHK;
+        ierr = MPI_Scatterv(&pCons[rowDispls[myDevWorldRank]], nRowsInDevWorld.data(), rowDispls.data(), MPI_DOUBLE, pscalar, nRows, MPI_DOUBLE, 0, devWorld); CHK;
     }
     ierr = MPI_Barrier(globalCpuWorld); CHK;
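Note that the hunks above show the VecGetArray calls but not the matching VecRestoreArray calls, which lie outside the displayed context. The borrow/restore pairing that PETSc expects looks roughly like this (a sketch of the standard pattern with an assumed helper name, not code copied from the commit):

// Illustrative helper (name and structure assumed): borrow the local arrays,
// work on them, then hand them back to PETSc so the Vec state stays valid.
#include <petscvec.h>

static PetscErrorCode accessPattern(Vec p, Vec b)
{
    PetscErrorCode ierr;
    PetscScalar* pscalar;
    PetscScalar* bscalar;

    ierr = VecGetArray(p, &pscalar); CHKERRQ(ierr);   // writable local data of p
    ierr = VecGetArray(b, &bscalar); CHKERRQ(ierr);   // writable local data of b

    // ... pscalar/bscalar are the pointers uploaded to AmgX in the diff above ...

    ierr = VecRestoreArray(p, &pscalar); CHKERRQ(ierr);   // return the arrays to PETSc
    ierr = VecRestoreArray(b, &bscalar); CHKERRQ(ierr);
    return 0;
}

Since the right-hand side b is only read, VecGetArrayRead would also fit for bscalar; the diff uses the non-const VecGetArray for both vectors.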