diff --git a/include/bout/globalindexer.hxx b/include/bout/globalindexer.hxx index e756ead3b2..bd4203a092 100644 --- a/include/bout/globalindexer.hxx +++ b/include/bout/globalindexer.hxx @@ -86,7 +86,7 @@ public: int localSize = size(); MPI_Comm comm = - std::is_same_v ? fieldmesh->getXcomm() : BoutComm::get(); + std::is_same_v ? fieldmesh->getXZcomm() : BoutComm::get(); fieldmesh->getMpi().MPI_Scan(&localSize, &globalEnd, 1, MPI_INT, MPI_SUM, comm); globalEnd--; int counter = globalStart = globalEnd - size() + 1; diff --git a/include/bout/hypre_interface.hxx b/include/bout/hypre_interface.hxx index 65d8875a9a..189e91d159 100644 --- a/include/bout/hypre_interface.hxx +++ b/include/bout/hypre_interface.hxx @@ -159,7 +159,7 @@ public: explicit HypreVector(IndexerPtr indConverter) : indexConverter(indConverter) { Mesh& mesh = *indConverter->getMesh(); const MPI_Comm comm = - std::is_same_v ? mesh.getXcomm() : BoutComm::get(); + std::is_same_v ? mesh.getXZcomm() : BoutComm::get(); HYPRE_BigInt jlower = indConverter->getGlobalStart(); HYPRE_BigInt jupper = jlower + indConverter->size() - 1; // inclusive end @@ -380,7 +380,7 @@ public: : hypre_matrix(new HYPRE_IJMatrix, MatrixDeleter{}), index_converter(indConverter) { Mesh* mesh = indConverter->getMesh(); const MPI_Comm comm = - std::is_same_v ? mesh->getXcomm() : BoutComm::get(); + std::is_same_v ? mesh->getXZcomm() : BoutComm::get(); parallel_transform = &mesh->getCoordinates()->getParallelTransform(); ilower = indConverter->getGlobalStart(); @@ -812,7 +812,7 @@ public: "values are: gmres, bicgstab, pcg") .withDefault(HYPRE_SOLVER_TYPE::bicgstab); - comm = std::is_same_v ? mesh.getXcomm() : BoutComm::get(); + comm = std::is_same_v ? mesh.getXZcomm() : BoutComm::get(); auto print_level = options["hypre_print_level"] diff --git a/include/bout/mesh.hxx b/include/bout/mesh.hxx index a1ed6a9011..d4f7f7ac7f 100644 --- a/include/bout/mesh.hxx +++ b/include/bout/mesh.hxx @@ -3,7 +3,7 @@ * * Interface for mesh classes. Contains standard variables and useful * routines. - * + * * Changelog * ========= * @@ -11,7 +11,7 @@ * * Removing coordinate system into separate * Coordinates class * * Adding index derivative functions from derivs.cxx - * + * * 2010-06 Ben Dudson, Sean Farley * * Initial version, adapted from GridData class * * Incorporates code from topology.cpp and Communicator @@ -20,7 +20,7 @@ * Copyright 2010-2025 BOUT++ contributors * * Contact: Ben Dudson, dudson2@llnl.gov - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -58,8 +58,6 @@ class Mesh; #include "bout/sys/range.hxx" // RangeIterator #include "bout/unused.hxx" -#include "mpi.h" - #include #include #include @@ -405,6 +403,7 @@ public: } ///< Return communicator containing all processors in X virtual MPI_Comm getXcomm(int jy) const = 0; ///< Return X communicator virtual MPI_Comm getYcomm(int jx) const = 0; ///< Return Y communicator + virtual MPI_Comm getXZcomm() const = 0; ///< Communicator in X-Z /// Return pointer to the mesh's MPI Wrapper object MpiWrapper& getMpi() { return *mpi; } diff --git a/include/bout/petsc_interface.hxx b/include/bout/petsc_interface.hxx index 5239039f0c..2ce71d0549 100644 --- a/include/bout/petsc_interface.hxx +++ b/include/bout/petsc_interface.hxx @@ -79,7 +79,7 @@ inline MPI_Comm getComm([[maybe_unused]] const T& field) { template <> inline MPI_Comm getComm([[maybe_unused]] const FieldPerp& field) { - return field.getMesh()->getXcomm(); + return field.getMesh()->getXZcomm(); } template @@ -293,7 +293,7 @@ public: PetscMatrix(IndexerPtr indConverter, bool preallocate = true) : matrix(new Mat()), indexConverter(indConverter), pt(&indConverter->getMesh()->getCoordinates()->getParallelTransform()) { - MPI_Comm comm = std::is_same_v ? indConverter->getMesh()->getXcomm() + MPI_Comm comm = std::is_same_v ? indConverter->getMesh()->getXZcomm() : BoutComm::get(); const int size = indexConverter->size(); diff --git a/src/invert/laplace/impls/petsc/petsc_laplace.cxx b/src/invert/laplace/impls/petsc/petsc_laplace.cxx index af65a1cd95..89ba25405b 100644 --- a/src/invert/laplace/impls/petsc/petsc_laplace.cxx +++ b/src/invert/laplace/impls/petsc/petsc_laplace.cxx @@ -140,7 +140,7 @@ LaplacePetsc::LaplacePetsc(Options* opt, const CELL_LOC loc, Mesh* mesh_in, [[maybe_unused]] Solver* solver) : Laplacian(opt, loc, mesh_in), A(0.0, mesh_in), C1(1.0, mesh_in), C2(1.0, mesh_in), D(1.0, mesh_in), Ex(0.0, mesh_in), Ez(0.0, mesh_in), issetD(false), issetC(false), - issetE(false), comm(localmesh->getXcomm()), + issetE(false), comm(localmesh->getXZcomm()), opts(opt == nullptr ? &(Options::root()["laplace"]) : opt), // WARNING: only a few of these options actually make sense: see the // PETSc documentation to work out which they are (possibly diff --git a/src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx b/src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx index f8cb52b3da..3e55262b6d 100644 --- a/src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx +++ b/src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx @@ -155,7 +155,7 @@ LaplaceXZpetsc::LaplaceXZpetsc(Mesh* m, Options* opt, const CELL_LOC loc) .withDefault("petsc"); // Get MPI communicator - MPI_Comm comm = localmesh->getXcomm(); + MPI_Comm comm = localmesh->getXZcomm(); // Local size int localN = (localmesh->xend - localmesh->xstart + 1) * (localmesh->LocalNz); diff --git a/src/mesh/impls/bout/boutmesh.cxx b/src/mesh/impls/bout/boutmesh.cxx index ea1b6a41b2..fea88c536e 100644 --- a/src/mesh/impls/bout/boutmesh.cxx +++ b/src/mesh/impls/bout/boutmesh.cxx @@ -52,6 +52,7 @@ #include #include +#include #include #include @@ -106,6 +107,9 @@ BoutMesh::~BoutMesh() { if (comm_outer != MPI_COMM_NULL) { MPI_Comm_free(&comm_outer); } + if (comm_xz != MPI_COMM_NULL) { + MPI_Comm_free(&comm_xz); + } } BoutMesh::YDecompositionIndices @@ -665,10 +669,43 @@ int BoutMesh::load() { return 0; } +namespace { +auto make_XZ_communicator(const BoutMesh& mesh, MPI_Group group_world) -> MPI_Comm { + std::vector ranks; + + const int yp = mesh.getYProcIndex(); + + // All processors with the same Y index + for (int xp = 0; xp < mesh.getNXPE(); ++xp) { + for (int zp = 0; zp < mesh.getNZPE(); ++zp) { + ranks.push_back(mesh.getProcIndex(xp, yp, zp)); + } + } + MPI_Group group{}; + if (MPI_Group_incl(group_world, static_cast(ranks.size()), ranks.data(), &group) + != MPI_SUCCESS) { + throw BoutException("Could not create X-Z communication group for ranks {}", + fmt::join(ranks, ", ")); + } + + MPI_Comm comm_xz{}; + if (MPI_Comm_create(BoutComm::get(), group, &comm_xz) != MPI_SUCCESS) { + throw BoutException("Could not create X-Z communicator for yp={} (xind={}, yind={}, " + "zind={}) ranks={}", + yp, mesh.getXProcIndex(), mesh.getYProcIndex(), + mesh.getZProcIndex(), fmt::join(ranks, ", ")); + } + + return comm_xz; +} +} // namespace + void BoutMesh::createCommunicators() { MPI_Group group_world{}; MPI_Comm_group(BoutComm::get(), &group_world); // Get the entire group + comm_xz = make_XZ_communicator(*this, group_world); + ////////////////////////////////////////////////////// /// Communicator in X @@ -1038,7 +1075,9 @@ void BoutMesh::createXBoundaries() { } } -int BoutMesh::getProcIndex(int X, int Y, int Z) const { return Y * NXPE + X; } +int BoutMesh::getProcIndex(int X, int Y, int Z) const { + return (((Z * NYPE) + Y) * NXPE) + X; +} void BoutMesh::createYBoundaries() { if (MYG <= 0) { @@ -2218,9 +2257,9 @@ void BoutMesh::topology() { } for (int i = 0; i < limiter_count; ++i) { - int const yind = limiter_yinds[i]; - int const xstart = limiter_xstarts[i]; - int const xend = limiter_xends[i]; + const int yind = limiter_yinds[i]; + const int xstart = limiter_xstarts[i]; + const int xend = limiter_xends[i]; output_info.write("Adding a limiter between y={} and {}. X indices {} to {}\n", yind, yind + 1, xstart, xend); add_target(yind, xstart, xend); diff --git a/src/mesh/impls/bout/boutmesh.hxx b/src/mesh/impls/bout/boutmesh.hxx index b42bf325b5..49eded2ae8 100644 --- a/src/mesh/impls/bout/boutmesh.hxx +++ b/src/mesh/impls/bout/boutmesh.hxx @@ -107,6 +107,7 @@ public: MPI_Comm getXcomm(int UNUSED(jy)) const override { return comm_x; } /// Return communicator containing all processors in Y MPI_Comm getYcomm(int xpos) const override; + MPI_Comm getXZcomm() const override { return comm_xz; } /// Is local X index \p jx periodic in Y? /// @@ -455,6 +456,8 @@ private: /// Communicator containing all processors in X MPI_Comm comm_x{MPI_COMM_NULL}; + /// Communicator for all processors in an XZ plane + MPI_Comm comm_xz{MPI_COMM_NULL}; ////////////////////////////////////////////////// // Surface communications diff --git a/src/mesh/mesh.cxx b/src/mesh/mesh.cxx index 829c150a15..31b3c08889 100644 --- a/src/mesh/mesh.cxx +++ b/src/mesh/mesh.cxx @@ -516,7 +516,7 @@ int Mesh::globalStartIndex2D() { int Mesh::globalStartIndexPerp() { int localSize = localSizePerp(); int cumulativeSize = 0; - mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXcomm()); + mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXZcomm()); return cumulativeSize - localSize; } diff --git a/tests/unit/fake_mesh.hxx b/tests/unit/fake_mesh.hxx index aa0220609b..6dbbd6200b 100644 --- a/tests/unit/fake_mesh.hxx +++ b/tests/unit/fake_mesh.hxx @@ -139,6 +139,7 @@ public: } MPI_Comm getXcomm(int UNUSED(jy)) const override { return BoutComm::get(); } MPI_Comm getYcomm(int UNUSED(jx)) const override { return BoutComm::get(); } + MPI_Comm getXZcomm() const override { return BoutComm::get(); } // Periodic Y int ix_separatrix{1000000}; // separatrix index