Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/somd2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@
# Store the sire version.
from sire import __version__ as _sire_version
from sire import __revisionid__ as _sire_revisionid

# Store the ghostly version.
from ghostly import __version__ as _ghostly_version

# Store the loch version.
from loch import __version__ as _loch_version
12 changes: 11 additions & 1 deletion src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,20 @@ def __init__(self, system, config):
self._perturbed_box = None

# Log the versions of somd2 and sire.
from somd2 import __version__, _sire_version, _sire_revisionid
from somd2 import (
__version__,
_sire_version,
_sire_revisionid,
_ghostly_version,
_loch_version,
)

_logger.info(f"somd2 version: {__version__}")
_logger.info(f"sire version: {_sire_version}+{_sire_revisionid}")
if self._config.ghost_modifications:
_logger.info(f"ghostly version: {_ghostly_version}")
if self._config.gcmc:
_logger.info(f"loch version: {_loch_version}")

# Flag whether frames are being saved.
if (
Expand Down
91 changes: 54 additions & 37 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,30 +215,31 @@ def _create_dynamics(
# Initialise the dynamics object list.
self._dynamics = []

# A set of visited device indices.
devices = set()
# Per-device memory tracking for estimation.
device_mem = {}

# Determine whether there is a remainder in the number of replicas.
# Work out how many replicas are assigned to each device.
# Replicas are assigned round-robin, so the first (num_replicas % num_gpus)
# devices get one extra replica.
base = floor(num_replicas / num_gpus)
remainder = num_replicas % num_gpus

# Store the number of contexts for each device. The last device will
# have remainder contexts, while all others have
contexts_per_device = num_replicas * [floor(num_replicas / num_gpus)]

# Set the last device to have the remainder contexts.
contexts_per_device[-1] = remainder
contexts_per_device = [
base + (1 if i < remainder else 0) for i in range(num_gpus)
]

# Create the dynamics objects in serial.
for i, (lam, scale) in enumerate(zip(lambdas, rest2_scale_factors)):
# Work out the device index.
device = i % num_gpus

# If we've not seen this device before then get the memory statistics
# prior to creating the dynamics object and GCMC sampler.
if device not in devices:
used_mem_before, free_mem_before, total_mem = self._check_device_memory(
device
)
# Record baseline memory before the first replica on this device.
if device not in device_mem:
used_before, _, total_mem = self._check_device_memory(device)
device_mem[device] = {
"before": used_before,
"total": total_mem,
"count": 0,
}

# This is a restart, get the system for this replica.
if isinstance(system, list):
Expand Down Expand Up @@ -321,19 +322,43 @@ def _create_dynamics(
# Append the dynamics object.
self._dynamics.append(dynamics)

# Check the memory footprint for this device.
if not device in devices:
# Add the device to the set of visited devices.
devices.add(device)
# Track memory footprint for this device.
info = device_mem[device]
info["count"] += 1
num_contexts = contexts_per_device[device]

# Get the current memory usage.
used_mem, free_mem, total_mem = self._check_device_memory(device)
# Estimate memory after the first or second replica.
if info["count"] == 1:
used_mem, _, _ = self._check_device_memory(device)
info["after_first"] = used_mem

# Work out the memory used by this dynamics object and GCMC sampler.
mem_used = used_mem - used_mem_before
if num_contexts == 1:
# Only one replica on this device, use actual measurement.
est_total = used_mem
else:
# Wait for the second replica to get the marginal cost.
est_total = None

elif info["count"] == 2:
used_mem, _, _ = self._check_device_memory(device)
# The first replica includes one-time context overhead.
# The marginal cost of subsequent replicas is the difference
# between the second and first.
first_cost = info["after_first"] - info["before"]
marginal_cost = used_mem - info["after_first"]
est_total = (
info["before"] + first_cost + marginal_cost * (num_contexts - 1)
)
_logger.info(
f"Memory per replica on device {device}: "
f"first = {first_cost / (1024**2):.0f} MiB, "
f"marginal = {marginal_cost / (1024**2):.0f} MiB"
)
else:
est_total = None

# Work out the estimated total after all replicas have been created.
est_total = mem_used * contexts_per_device[device] + used_mem_before
if est_total is not None:
total_mem = info["total"]

# If this exceeds the total memory, raise an error.
if est_total > total_mem:
Expand Down Expand Up @@ -562,18 +587,10 @@ def _check_device_memory(device_index=0):

pynvml.nvmlInit()

# Find matching device by name
device_count = pynvml.nvmlDeviceGetCount()
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle)

if name in device.name or device.name in name:
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlShutdown()
return (memory.used, memory.free, memory.total)

handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlShutdown()
return (memory.used, memory.free, memory.total)
except Exception as e:
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
_logger.error(msg)
Expand Down
Loading