Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/bout/globalindexer.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public:

int localSize = size();
MPI_Comm comm =
std::is_same_v<T, FieldPerp> ? fieldmesh->getXcomm() : BoutComm::get();
std::is_same_v<T, FieldPerp> ? fieldmesh->getXZcomm() : BoutComm::get();
fieldmesh->getMpi().MPI_Scan(&localSize, &globalEnd, 1, MPI_INT, MPI_SUM, comm);
globalEnd--;
int counter = globalStart = globalEnd - size() + 1;
Expand Down
6 changes: 3 additions & 3 deletions include/bout/hypre_interface.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public:
explicit HypreVector(IndexerPtr<T> indConverter) : indexConverter(indConverter) {
Mesh& mesh = *indConverter->getMesh();
const MPI_Comm comm =
std::is_same_v<T, FieldPerp> ? mesh.getXcomm() : BoutComm::get();
std::is_same_v<T, FieldPerp> ? mesh.getXZcomm() : BoutComm::get();

HYPRE_BigInt jlower = indConverter->getGlobalStart();
HYPRE_BigInt jupper = jlower + indConverter->size() - 1; // inclusive end
Expand Down Expand Up @@ -380,7 +380,7 @@ public:
: hypre_matrix(new HYPRE_IJMatrix, MatrixDeleter{}), index_converter(indConverter) {
Mesh* mesh = indConverter->getMesh();
const MPI_Comm comm =
std::is_same_v<T, FieldPerp> ? mesh->getXcomm() : BoutComm::get();
std::is_same_v<T, FieldPerp> ? mesh->getXZcomm() : BoutComm::get();
parallel_transform = &mesh->getCoordinates()->getParallelTransform();

ilower = indConverter->getGlobalStart();
Expand Down Expand Up @@ -812,7 +812,7 @@ public:
"values are: gmres, bicgstab, pcg")
.withDefault(HYPRE_SOLVER_TYPE::bicgstab);

comm = std::is_same_v<T, FieldPerp> ? mesh.getXcomm() : BoutComm::get();
comm = std::is_same_v<T, FieldPerp> ? mesh.getXZcomm() : BoutComm::get();

auto print_level =
options["hypre_print_level"]
Expand Down
9 changes: 4 additions & 5 deletions include/bout/mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
*
* Interface for mesh classes. Contains standard variables and useful
* routines.
*
*
* Changelog
* =========
*
* 2014-12 Ben Dudson <bd512@york.ac.uk>
* * Removing coordinate system into separate
* Coordinates class
* * Adding index derivative functions from derivs.cxx
*
*
* 2010-06 Ben Dudson, Sean Farley
* * Initial version, adapted from GridData class
* * Incorporates code from topology.cpp and Communicator
Expand All @@ -20,7 +20,7 @@
* Copyright 2010-2025 BOUT++ contributors
*
* Contact: Ben Dudson, dudson2@llnl.gov
*
*
* This file is part of BOUT++.
*
* BOUT++ is free software: you can redistribute it and/or modify
Expand Down Expand Up @@ -58,8 +58,6 @@ class Mesh;
#include "bout/sys/range.hxx" // RangeIterator
#include "bout/unused.hxx"

#include "mpi.h"

#include <map>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -405,6 +403,7 @@ public:
} ///< Return communicator containing all processors in X
virtual MPI_Comm getXcomm(int jy) const = 0; ///< Return X communicator
virtual MPI_Comm getYcomm(int jx) const = 0; ///< Return Y communicator
virtual MPI_Comm getXZcomm() const = 0; ///< Communicator in X-Z

/// Return pointer to the mesh's MPI Wrapper object
MpiWrapper& getMpi() { return *mpi; }
Expand Down
4 changes: 2 additions & 2 deletions include/bout/petsc_interface.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ inline MPI_Comm getComm([[maybe_unused]] const T& field) {

template <>
inline MPI_Comm getComm([[maybe_unused]] const FieldPerp& field) {
return field.getMesh()->getXcomm();
return field.getMesh()->getXZcomm();
}

template <class T>
Expand Down Expand Up @@ -293,7 +293,7 @@ public:
PetscMatrix(IndexerPtr<T> indConverter, bool preallocate = true)
: matrix(new Mat()), indexConverter(indConverter),
pt(&indConverter->getMesh()->getCoordinates()->getParallelTransform()) {
MPI_Comm comm = std::is_same_v<T, FieldPerp> ? indConverter->getMesh()->getXcomm()
MPI_Comm comm = std::is_same_v<T, FieldPerp> ? indConverter->getMesh()->getXZcomm()
: BoutComm::get();
const int size = indexConverter->size();

Expand Down
2 changes: 1 addition & 1 deletion src/invert/laplace/impls/petsc/petsc_laplace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ LaplacePetsc::LaplacePetsc(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
[[maybe_unused]] Solver* solver)
: Laplacian(opt, loc, mesh_in), A(0.0, mesh_in), C1(1.0, mesh_in), C2(1.0, mesh_in),
D(1.0, mesh_in), Ex(0.0, mesh_in), Ez(0.0, mesh_in), issetD(false), issetC(false),
issetE(false), comm(localmesh->getXcomm()),
issetE(false), comm(localmesh->getXZcomm()),
opts(opt == nullptr ? &(Options::root()["laplace"]) : opt),
// WARNING: only a few of these options actually make sense: see the
// PETSc documentation to work out which they are (possibly
Expand Down
2 changes: 1 addition & 1 deletion src/invert/laplacexz/impls/petsc/laplacexz-petsc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ LaplaceXZpetsc::LaplaceXZpetsc(Mesh* m, Options* opt, const CELL_LOC loc)
.withDefault("petsc");

// Get MPI communicator
MPI_Comm comm = localmesh->getXcomm();
MPI_Comm comm = localmesh->getXZcomm();

// Local size
int localN = (localmesh->xend - localmesh->xstart + 1) * (localmesh->LocalNz);
Expand Down
47 changes: 43 additions & 4 deletions src/mesh/impls/bout/boutmesh.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include <bout/utils.hxx>

#include <fmt/format.h>
#include <fmt/ranges.h>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -106,6 +107,9 @@ BoutMesh::~BoutMesh() {
if (comm_outer != MPI_COMM_NULL) {
MPI_Comm_free(&comm_outer);
}
if (comm_xz != MPI_COMM_NULL) {
MPI_Comm_free(&comm_xz);
}
}

BoutMesh::YDecompositionIndices
Expand Down Expand Up @@ -665,10 +669,43 @@ int BoutMesh::load() {
return 0;
}

namespace {
auto make_XZ_communicator(const BoutMesh& mesh, MPI_Group group_world) -> MPI_Comm {
std::vector<int> ranks;

const int yp = mesh.getYProcIndex();

// All processors with the same Y index
for (int xp = 0; xp < mesh.getNXPE(); ++xp) {
for (int zp = 0; zp < mesh.getNZPE(); ++zp) {
ranks.push_back(mesh.getProcIndex(xp, yp, zp));
}
}
MPI_Group group{};
if (MPI_Group_incl(group_world, static_cast<int>(ranks.size()), ranks.data(), &group)
!= MPI_SUCCESS) {
throw BoutException("Could not create X-Z communication group for ranks {}",
fmt::join(ranks, ", "));
}

MPI_Comm comm_xz{};
if (MPI_Comm_create(BoutComm::get(), group, &comm_xz) != MPI_SUCCESS) {
throw BoutException("Could not create X-Z communicator for yp={} (xind={}, yind={}, "
"zind={}) ranks={}",
yp, mesh.getXProcIndex(), mesh.getYProcIndex(),
mesh.getZProcIndex(), fmt::join(ranks, ", "));
}

return comm_xz;
}
} // namespace

void BoutMesh::createCommunicators() {
MPI_Group group_world{};
MPI_Comm_group(BoutComm::get(), &group_world); // Get the entire group

comm_xz = make_XZ_communicator(*this, group_world);

//////////////////////////////////////////////////////
/// Communicator in X

Expand Down Expand Up @@ -1038,7 +1075,9 @@ void BoutMesh::createXBoundaries() {
}
}

int BoutMesh::getProcIndex(int X, int Y, int Z) const { return Y * NXPE + X; }
int BoutMesh::getProcIndex(int X, int Y, int Z) const {
return (((Z * NYPE) + Y) * NXPE) + X;
}

void BoutMesh::createYBoundaries() {
if (MYG <= 0) {
Expand Down Expand Up @@ -2218,9 +2257,9 @@ void BoutMesh::topology() {
}

for (int i = 0; i < limiter_count; ++i) {
int const yind = limiter_yinds[i];
int const xstart = limiter_xstarts[i];
int const xend = limiter_xends[i];
const int yind = limiter_yinds[i];
const int xstart = limiter_xstarts[i];
const int xend = limiter_xends[i];
output_info.write("Adding a limiter between y={} and {}. X indices {} to {}\n",
yind, yind + 1, xstart, xend);
add_target(yind, xstart, xend);
Expand Down
3 changes: 3 additions & 0 deletions src/mesh/impls/bout/boutmesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public:
MPI_Comm getXcomm(int UNUSED(jy)) const override { return comm_x; }
/// Return communicator containing all processors in Y
MPI_Comm getYcomm(int xpos) const override;
MPI_Comm getXZcomm() const override { return comm_xz; }

/// Is local X index \p jx periodic in Y?
///
Expand Down Expand Up @@ -455,6 +456,8 @@ private:

/// Communicator containing all processors in X
MPI_Comm comm_x{MPI_COMM_NULL};
/// Communicator for all processors in an XZ plane
MPI_Comm comm_xz{MPI_COMM_NULL};

//////////////////////////////////////////////////
// Surface communications
Expand Down
2 changes: 1 addition & 1 deletion src/mesh/mesh.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ int Mesh::globalStartIndex2D() {
int Mesh::globalStartIndexPerp() {
int localSize = localSizePerp();
int cumulativeSize = 0;
mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXcomm());
mpi->MPI_Scan(&localSize, &cumulativeSize, 1, MPI_INT, MPI_SUM, getXZcomm());
return cumulativeSize - localSize;
}

Expand Down
1 change: 1 addition & 0 deletions tests/unit/fake_mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ public:
}
MPI_Comm getXcomm(int UNUSED(jy)) const override { return BoutComm::get(); }
MPI_Comm getYcomm(int UNUSED(jx)) const override { return BoutComm::get(); }
MPI_Comm getXZcomm() const override { return BoutComm::get(); }

// Periodic Y
int ix_separatrix{1000000}; // separatrix index
Expand Down
Loading