diff --git a/src/somd2/__init__.py b/src/somd2/__init__.py index 6a45080..ac5eef1 100644 --- a/src/somd2/__init__.py +++ b/src/somd2/__init__.py @@ -37,3 +37,9 @@ # Store the sire version. from sire import __version__ as _sire_version from sire import __revisionid__ as _sire_revisionid + +# Store the ghostly version. +from ghostly import __version__ as _ghostly_version + +# Store the loch version. +from loch import __version__ as _loch_version diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index c401fc9..cbf8f43 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -117,10 +117,20 @@ def __init__(self, system, config): self._perturbed_box = None # Log the versions of somd2 and sire. - from somd2 import __version__, _sire_version, _sire_revisionid + from somd2 import ( + __version__, + _sire_version, + _sire_revisionid, + _ghostly_version, + _loch_version, + ) _logger.info(f"somd2 version: {__version__}") _logger.info(f"sire version: {_sire_version}+{_sire_revisionid}") + if self._config.ghost_modifications: + _logger.info(f"ghostly version: {_ghostly_version}") + if self._config.gcmc: + _logger.info(f"loch version: {_loch_version}") # Flag whether frames are being saved. if ( diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 34303f1..c6025a9 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -215,30 +215,31 @@ def _create_dynamics( # Initialise the dynamics object list. self._dynamics = [] - # A set of visited device indices. - devices = set() + # Per-device memory tracking for estimation. + device_mem = {} - # Determine whether there is a remainder in the number of replicas. + # Work out how many replicas are assigned to each device. + # Replicas are assigned round-robin, so the first (num_replicas % num_gpus) + # devices get one extra replica. + base = floor(num_replicas / num_gpus) remainder = num_replicas % num_gpus - - # Store the number of contexts for each device. The last device will - # have remainder contexts, while all others have - contexts_per_device = num_replicas * [floor(num_replicas / num_gpus)] - - # Set the last device to have the remainder contexts. - contexts_per_device[-1] = remainder + contexts_per_device = [ + base + (1 if i < remainder else 0) for i in range(num_gpus) + ] # Create the dynamics objects in serial. for i, (lam, scale) in enumerate(zip(lambdas, rest2_scale_factors)): # Work out the device index. device = i % num_gpus - # If we've not seen this device before then get the memory statistics - # prior to creating the dynamics object and GCMC sampler. - if device not in devices: - used_mem_before, free_mem_before, total_mem = self._check_device_memory( - device - ) + # Record baseline memory before the first replica on this device. + if device not in device_mem: + used_before, _, total_mem = self._check_device_memory(device) + device_mem[device] = { + "before": used_before, + "total": total_mem, + "count": 0, + } # This is a restart, get the system for this replica. if isinstance(system, list): @@ -321,19 +322,43 @@ def _create_dynamics( # Append the dynamics object. self._dynamics.append(dynamics) - # Check the memory footprint for this device. - if not device in devices: - # Add the device to the set of visited devices. - devices.add(device) + # Track memory footprint for this device. + info = device_mem[device] + info["count"] += 1 + num_contexts = contexts_per_device[device] - # Get the current memory usage. - used_mem, free_mem, total_mem = self._check_device_memory(device) + # Estimate memory after the first or second replica. + if info["count"] == 1: + used_mem, _, _ = self._check_device_memory(device) + info["after_first"] = used_mem - # Work out the memory used by this dynamics object and GCMC sampler. - mem_used = used_mem - used_mem_before + if num_contexts == 1: + # Only one replica on this device, use actual measurement. + est_total = used_mem + else: + # Wait for the second replica to get the marginal cost. + est_total = None + + elif info["count"] == 2: + used_mem, _, _ = self._check_device_memory(device) + # The first replica includes one-time context overhead. + # The marginal cost of subsequent replicas is the difference + # between the second and first. + first_cost = info["after_first"] - info["before"] + marginal_cost = used_mem - info["after_first"] + est_total = ( + info["before"] + first_cost + marginal_cost * (num_contexts - 1) + ) + _logger.info( + f"Memory per replica on device {device}: " + f"first = {first_cost / (1024**2):.0f} MiB, " + f"marginal = {marginal_cost / (1024**2):.0f} MiB" + ) + else: + est_total = None - # Work out the estimated total after all replicas have been created. - est_total = mem_used * contexts_per_device[device] + used_mem_before + if est_total is not None: + total_mem = info["total"] # If this exceeds the total memory, raise an error. if est_total > total_mem: @@ -562,18 +587,10 @@ def _check_device_memory(device_index=0): pynvml.nvmlInit() - # Find matching device by name - device_count = pynvml.nvmlDeviceGetCount() - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - name = pynvml.nvmlDeviceGetName(handle) - - if name in device.name or device.name in name: - memory = pynvml.nvmlDeviceGetMemoryInfo(handle) - pynvml.nvmlShutdown() - return (memory.used, memory.free, memory.total) - + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + memory = pynvml.nvmlDeviceGetMemoryInfo(handle) pynvml.nvmlShutdown() + return (memory.used, memory.free, memory.total) except Exception as e: msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}" _logger.error(msg)