diff --git a/hypnotoad/core/mesh.py b/hypnotoad/core/mesh.py index b5221f60..d26f32d2 100644 --- a/hypnotoad/core/mesh.py +++ b/hypnotoad/core/mesh.py @@ -3220,9 +3220,11 @@ def smoothnl(self, varname): if change < 1.0e-3: break - def plotGridCellEdges(self, ax=None, **kwargs): + def plotGridCellEdges(self, ax=None, exclude_penalty=True, **kwargs): """ Plot lines between cell corners + + exclude_penalty Exclude regions where penalty mask > 0.99? """ from matplotlib import pyplot from cycler import cycle @@ -3235,25 +3237,31 @@ def plotGridCellEdges(self, ax=None, **kwargs): for region in self.regions.values(): c = next(colors) label = region.myID - for i in range(region.nx + 1): - ax.plot( - region.Rxy.corners[i, :], - region.Zxy.corners[i, :], - c=c, - label=label, - **kwargs, - ) - label = None - label = region.myID - for j in range(region.ny + 1): - ax.plot( - region.Rxy.corners[:, j], - region.Zxy.corners[:, j], - c=c, - label=None, - **kwargs, - ) - label = None + + for i in range(region.nx): + for j in range(region.ny): + if exclude_penalty and region.penalty_mask[i, j] > 0.99: + continue + # Plot cell edges around (i,j) + ax.plot( + [ + region.Rxy.corners[i, j], + region.Rxy.corners[i, j + 1], + region.Rxy.corners[i + 1, j + 1], + region.Rxy.corners[i + 1, j], + ], + [ + region.Zxy.corners[i, j], + region.Zxy.corners[i, j + 1], + region.Zxy.corners[i + 1, j + 1], + region.Zxy.corners[i + 1, j], + ], + c=c, + label=label, + **kwargs, + ) + label = None + return ax def plotPenaltyMask(self, ax=None, **kwargs): from matplotlib import pyplot @@ -3283,8 +3291,14 @@ def plotPenaltyMask(self, ax=None, **kwargs): "k", alpha=penalty, ) + return ax + + def plotGridLines(self, ax=None, exclude_penalty=True, **kwargs): + """ + Plot lines through cell centers - def plotGridLines(self, ax=None, **kwargs): + exclude_penalty Exclude regions where penalty mask > 0.99? + """ from matplotlib import pyplot from cycler import cycle @@ -3298,20 +3312,36 @@ def plotGridLines(self, ax=None, **kwargs): for region in self.regions.values(): c = next(colors) label = region.myID + jmin = 0 + jmax = region.ny - 1 for i in range(region.nx): + if exclude_penalty: + jwhere = numpy.argwhere(region.penalty_mask[i, :] < 0.5) + if len(jwhere) == 0: + continue + jmin = jwhere[0][0] + jmax = jwhere[-1][0] ax.plot( - region.Rxy.centre[i, :], - region.Zxy.centre[i, :], + region.Rxy.centre[i, jmin : (jmax + 1)], + region.Zxy.centre[i, jmin : (jmax + 1)], c=c, label=label, **kwargs, ) label = None label = region.myID + imin = 0 + imax = region.nx - 1 for j in range(region.ny): + if exclude_penalty: + iwhere = numpy.argwhere(region.penalty_mask[:, j] < 0.5) + if len(iwhere) == 0: + continue + imin = iwhere[0][0] + imax = iwhere[-1][0] ax.plot( - region.Rxy.centre[:, j], - region.Zxy.centre[:, j], + region.Rxy.centre[imin : (imax + 1), j], + region.Zxy.centre[imin : (imax + 1), j], c=c, label=None, **kwargs, @@ -3319,6 +3349,7 @@ def plotGridLines(self, ax=None, **kwargs): label = None l = fig.legend() l.set_draggable(True) + return ax def plotPoints( self,