From ae36f72faa24eb71edcd4f5bb38b87be1c04eb1b Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 2 Feb 2026 15:10:33 +0000 Subject: [PATCH 1/6] Use primary context to reduce memory footprint. --- src/loch/_kernels.py | 322 ++++++++++----------------------- src/loch/_platforms/_base.py | 29 +-- src/loch/_platforms/_cuda.py | 60 ++---- src/loch/_platforms/_opencl.py | 26 --- src/loch/_sampler.py | 196 ++++++++++++-------- tests/test_energy.py | 90 +++++++++ 6 files changed, 323 insertions(+), 400 deletions(-) diff --git a/src/loch/_kernels.py b/src/loch/_kernels.py index 185f708..710b107 100644 --- a/src/loch/_kernels.py +++ b/src/loch/_kernels.py @@ -27,7 +27,7 @@ // Platform-specific definitions for CUDA and OpenCL compatibility #ifdef __OPENCL_VERSION__ #define KERNEL __kernel - #define DEVICE // OpenCL: file-scope variables visible to all kernels + #define DEVICE #define GLOBAL __global #define LOCAL __local #define GET_GLOBAL_ID(dim) get_global_id(dim) @@ -46,7 +46,7 @@ #pragma OPENCL EXTENSION cl_khr_fp64 : enable #else #define KERNEL extern "C" __global__ - #define DEVICE __device__ // CUDA: device memory (mutable) + #define DEVICE __device__ #define GLOBAL #define LOCAL __shared__ #define GET_GLOBAL_ID(dim) (threadIdx.x + blockIdx.x * blockDim.x) @@ -62,229 +62,18 @@ const int num_water_positions = 3 * num_points; const float prefactor = 332.0637090025476f; - // Reaction field parameters. - DEVICE float rf_dielectric; - DEVICE float rf_kappa; - DEVICE float rf_cutoff; - DEVICE float rf_correction; - - // Soft-core parameters. - DEVICE float coulomb_power; - DEVICE float shift_coulomb; - DEVICE float shift_delta; - - // Triclinic cell information. - DEVICE float cell_matrix[3][3]; - DEVICE float cell_matrix_inverse[3][3]; - DEVICE float M[3][3]; - - // Atom properties. 
- DEVICE float sigma[num_atoms]; - DEVICE float epsilon[num_atoms]; - DEVICE float charge[num_atoms]; - DEVICE float alpha[num_atoms]; - DEVICE float position[num_atoms * 3]; - DEVICE int is_ghost_water[num_atoms]; - DEVICE int is_ghost_fep[num_atoms]; - - // Water properties. - DEVICE float sigma_water[num_points]; - DEVICE float epsilon_water[num_points]; - DEVICE float charge_water[num_points]; - DEVICE int water_idx[num_waters]; - DEVICE int water_state[num_waters]; - #ifndef __OPENCL_VERSION__ extern "C" { #endif - // Intialisation of the cell information for periodic triclinic boxes. - KERNEL void setCellMatrix( - GLOBAL float* matrix, - GLOBAL float* matrix_inverse, - GLOBAL float* m) - { - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - cell_matrix[i][j] = matrix[i * 3 + j]; - cell_matrix_inverse[i][j] = matrix_inverse[i * 3 + j]; - M[i][j] = m[i * 3 + j]; - } - } - } - - // Set the reaction field parameters. - KERNEL void setReactionField(float cutoff, float dielectric) - { - rf_dielectric = dielectric; - rf_cutoff = cutoff; - const float rf_cutoff2 = cutoff * cutoff; - const float rf_cutoff3_inv = 1.0f / (rf_cutoff * rf_cutoff2); - rf_kappa = rf_cutoff3_inv * (dielectric - 1.0f) / (2.0f * dielectric + 1.0f); - rf_correction = (1.0 / rf_cutoff) + rf_kappa * rf_cutoff2; - } - - // Set the soft-core parameters. - KERNEL void setSoftCore(float power, float coulomb, float delta) - { - coulomb_power = power; - shift_coulomb = coulomb; - shift_delta = delta; - } - - // Set the properties of each atom. 
- KERNEL void setAtomProperties( - GLOBAL float* charges, - GLOBAL float* sigmas, - GLOBAL float* epsilons, - GLOBAL float* alphas, - GLOBAL int* ghost_water, - GLOBAL int* ghost_fep) - { - const int tidx = GET_GLOBAL_ID(0); - - if (tidx < num_atoms) - { - charge[tidx] = charges[tidx]; - sigma[tidx] = sigmas[tidx]; - epsilon[tidx] = epsilons[tidx]; - alpha[tidx] = alphas[tidx]; - is_ghost_water[tidx] = ghost_water[tidx]; - is_ghost_fep[tidx] = ghost_fep[tidx]; - } - } - - // Set the positions of each atom. - KERNEL void setAtomPositions(GLOBAL float* positions, float scale) - { - const int tidx = GET_GLOBAL_ID(0); - - if (tidx < num_atoms) - { - position[tidx * 3] = scale * positions[tidx * 3]; - position[tidx * 3 + 1] = scale * positions[tidx * 3 + 1]; - position[tidx * 3 + 2] = scale * positions[tidx * 3 + 2]; - } - } - - // Set the properties of each water atom. - KERNEL void setWaterProperties( - GLOBAL float* charges, - GLOBAL float* sigmas, - GLOBAL float* epsilons, - GLOBAL int* idx, - GLOBAL int* state) - { - for (int i = 0; i < num_points; i++) - { - charge_water[i] = charges[i]; - sigma_water[i] = sigmas[i]; - epsilon_water[i] = epsilons[i]; - } - - for (int i = 0; i < num_waters; i++) - { - water_idx[i] = idx[i]; - water_state[i] = state[i]; - } - } - - // Update a single water. - KERNEL void updateWater(int idx, int state, int is_insertion, GLOBAL float* new_position) - { - // Set the new state. - water_state[idx] = state; - - // Get the water oxygen index in the context. - int idx_context = water_idx[idx]; - - for (int i = 0; i < num_points; i++) - { - // Ghost water. - if (state == 0) - { - charge[idx_context + i] = 0.0f; - epsilon[idx_context + i] = 0.0f; - is_ghost_water[idx_context + i] = 1; - } - else - { - charge[idx_context + i] = charge_water[i]; - epsilon[idx_context + i] = epsilon_water[i]; - is_ghost_water[idx_context + i] = 0; - } - - // Update the position of the water. 
We don't use the state to determine - // whether an insertion is performed, since we don't need to update the - // positions when a deletion move is rejected, which would also set the - // state to 1. - if (is_insertion == 1) - { - position[3 * idx_context + 3 * i] = new_position[3 * i]; - position[3 * idx_context + 3 * i + 1] = new_position[3 * i + 1]; - position[3 * idx_context + 3 * i + 2] = new_position[3 * i + 2]; - } - } - } - - // Calculate the delta that needs to be subtracted from the interatomic distance - // so that the atoms are wrapped to the same periodic box. - DEVICE void wrapDelta(float* v0, float* v1, float* delta_box) - { - // Work out the positions of v0 and v1 in "box" space. - float v0_box[3]; - float v1_box[3]; - for (int i = 0; i < 3; i++) - { - v0_box[i] = 0.0f; - v1_box[i] = 0.0f; - - for (int j = 0; j < 3; j++) - { - v0_box[i] += cell_matrix_inverse[i][j] * v0[j]; - v1_box[i] += cell_matrix_inverse[i][j] * v1[j]; - } - } - - // Now work out the distance between v0 and v1 in "box" space. - for (int i = 0; i < 3; i++) - { - delta_box[i] = v1_box[i] - v0_box[i]; - } - - // Extract the integer and fractional parts of the distance. - int int_x = (int)delta_box[0]; - int int_y = (int)delta_box[1]; - int int_z = (int)delta_box[2]; - float frac_x = delta_box[0] - int_x; - float frac_y = delta_box[1] - int_y; - float frac_z = delta_box[2] - int_z; - - // Shift to the box (branchless). - int_x += (int)floorf(frac_x + 0.5f); - int_y += (int)floorf(frac_y + 0.5f); - int_z += (int)floorf(frac_z + 0.5f); - - // Calculate the shifts over the box vectors. - delta_box[0] = 0.0f; - delta_box[1] = 0.0f; - delta_box[2] = 0.0f; - for (int i = 0; i < 3; i++) - { - delta_box[0] += cell_matrix[i][0] * int_x; - delta_box[1] += cell_matrix[i][1] * int_y; - delta_box[2] += cell_matrix[i][2] * int_z; - } - } - // Calculate the distance between two atoms within the periodic box. 
DEVICE void distance2( float* v0, float* v1, - float* dist2) + float* dist2, + GLOBAL const float* cell_matrix_inverse, + GLOBAL const float* metric_matrix) { // Work out the positions of v0 and v1 in "box" space. float v0_box[3]; @@ -296,8 +85,8 @@ for (int j = 0; j < 3; j++) { - v0_box[i] += cell_matrix_inverse[i][j] * v0[j]; - v1_box[i] += cell_matrix_inverse[i][j] * v1[j]; + v0_box[i] += cell_matrix_inverse[i * 3 + j] * v0[j]; + v1_box[i] += cell_matrix_inverse[i * 3 + j] * v1[j]; } } @@ -331,7 +120,7 @@ for (int j = 0; j < 3; j++) { - delta_box[i] += M[i][j] * frac_dist[j]; + delta_box[i] += metric_matrix[i * 3 + j] * frac_dist[j]; } } *dist2 = frac_x * delta_box[0] + frac_y * delta_box[1] + frac_z * delta_box[2]; @@ -425,6 +214,56 @@ v[8] = fmaf(x[2][0], M[0][2], fmaf(x[2][1], M[1][2], fmaf(x[2][2], M[2][2], mean_M[2]))); } + // Update a single water. + KERNEL void updateWater( + int idx, + int state, + int is_insertion, + GLOBAL float* new_position, + GLOBAL float* position, + GLOBAL float* charge, + GLOBAL float* epsilon, + GLOBAL int* is_ghost_water, + GLOBAL int* water_state, + GLOBAL const int* water_idx, + GLOBAL const float* charge_water, + GLOBAL const float* epsilon_water) + { + // Set the new state. + water_state[idx] = state; + + // Get the water oxygen index in the context. + int idx_context = water_idx[idx]; + + for (int i = 0; i < num_points; i++) + { + // Ghost water. + if (state == 0) + { + charge[idx_context + i] = 0.0f; + epsilon[idx_context + i] = 0.0f; + is_ghost_water[idx_context + i] = 1; + } + else + { + charge[idx_context + i] = charge_water[i]; + epsilon[idx_context + i] = epsilon_water[i]; + is_ghost_water[idx_context + i] = 0; + } + + // Update the position of the water. We don't use the state to determine + // whether an insertion is performed, since we don't need to update the + // positions when a deletion move is rejected, which would also set the + // state to 1. 
+ if (is_insertion == 1) + { + position[3 * idx_context + 3 * i] = new_position[3 * i]; + position[3 * idx_context + 3 * i + 1] = new_position[3 * i + 1]; + position[3 * idx_context + 3 * i + 2] = new_position[3 * i + 2]; + } + } + } + // Generate a random position and orientation within the GCMC sphere // for each trial insertion. KERNEL void generateWater( @@ -435,7 +274,8 @@ int is_target, GLOBAL float* randoms_rotation, GLOBAL float* randoms_position, - GLOBAL float* randoms_radius) + GLOBAL float* randoms_radius, + GLOBAL const float* cell_matrix) { // Work out the thread index. const int tidx = GET_GLOBAL_ID(0); @@ -505,7 +345,7 @@ xyz[i] = 0.0f; for (int j = 0; j < 3; j++) { - xyz[i] += r[j] * cell_matrix[i][j]; + xyz[i] += r[j] * cell_matrix[i * 3 + j]; } } } @@ -533,7 +373,26 @@ GLOBAL float* energy_lj, GLOBAL int* deletion_candidates, GLOBAL int* is_deletion, - int is_fep) + int is_fep, + GLOBAL const float* position, + GLOBAL const float* charge, + GLOBAL const float* sigma, + GLOBAL const float* epsilon, + GLOBAL const float* alpha, + GLOBAL const int* is_ghost_water, + GLOBAL const int* is_ghost_fep, + GLOBAL const float* sigma_water, + GLOBAL const float* epsilon_water, + GLOBAL const float* charge_water, + GLOBAL const int* water_idx, + GLOBAL const float* cell_matrix_inverse, + GLOBAL const float* metric_matrix, + float rf_cutoff, + float rf_kappa, + float rf_correction, + float sc_coulomb_power, + float sc_shift_coulomb, + float sc_shift_delta) { // Work out the atom index. const int idx_atom = GET_GLOBAL_ID(0); @@ -636,7 +495,7 @@ // Calculate the squared distance between the atoms. float r2; - distance2(v0, v1, &r2); + distance2(v0, v1, &r2, cell_matrix_inverse, metric_matrix); // The distance is within the cut-off. if (r2 < cutoff2) @@ -694,24 +553,24 @@ } // Compute the Lennard-Jones interaction. 
- const float delta_lj = shift_delta * a; + const float delta_lj = sc_shift_delta * a; const float s6 = powf(s, 6.0f) / powf((s * delta_lj) + (r * r), 3.0f); energy_lj[idx] += 4.0f * e * s6 * (s6 - 1.0f); // Compute the Coulomb power expression. float cpe; - if (coulomb_power == 0.0f) + if (sc_coulomb_power == 0.0f) { cpe = 1.0f; } else { - cpe = powf((1.0f - a), coulomb_power); + cpe = powf((1.0f - a), sc_coulomb_power); } // Compute the Coulomb interaction. energy_coul[idx] += (q0 * q1) * - ((cpe / sqrtf((shift_coulomb * shift_coulomb * a) + ((cpe / sqrtf((sc_shift_coulomb * sc_shift_coulomb * a) + (r * r))) + (rf_kappa * r2) - rf_correction); } @@ -797,7 +656,12 @@ KERNEL void findDeletionCandidates( GLOBAL int* candidates, GLOBAL float* target, - float radius) + float radius, + GLOBAL const float* position, + GLOBAL const int* water_idx, + GLOBAL const int* water_state, + GLOBAL const float* cell_matrix_inverse, + GLOBAL const float* metric_matrix) { const int tidx = GET_GLOBAL_ID(0); @@ -820,7 +684,7 @@ // Calculate the distance between the water and the target. float r2; - distance2(v, target, &r2); + distance2(v, target, &r2, cell_matrix_inverse, metric_matrix); // The water is within the GCMC sphere. Flag it as a candidate. if (r2 < radius * radius) diff --git a/src/loch/_platforms/_base.py b/src/loch/_platforms/_base.py index f192d76..1a291c8 100644 --- a/src/loch/_platforms/_base.py +++ b/src/loch/_platforms/_base.py @@ -91,9 +91,8 @@ def compile_kernels(self) -> _Dict[str, _Callable]: ------- dict Dictionary mapping kernel names to callable kernel functions. - Expected keys: 'cell', 'rf', 'softcore', 'atom_properties', - 'atom_positions', 'water_properties', 'update_water', 'deletion', - 'water', 'energy', 'acceptance'. + Expected keys: 'update_water', 'deletion', 'water', 'energy', + 'acceptance'. 
""" pass @@ -151,32 +150,12 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: """ pass - @_abstractmethod - def push_context(self): - """ - Push GPU context onto the context stack. - - For CUDA, this pushes the context onto the driver stack. - For OpenCL, this is a no-op as OpenCL doesn't use context stacking. - """ - pass - - @_abstractmethod - def pop_context(self): - """ - Pop GPU context from the context stack. - - For CUDA, this pops the context from the driver stack. - For OpenCL, this is a no-op as OpenCL doesn't use context stacking. - """ - pass - @_abstractmethod def cleanup(self): """ - Clean up GPU resources and detach context. + Clean up GPU resources and release context. - This method should release all GPU memory and detach the context. + This method should release all GPU memory and context references. It is called during shutdown to ensure proper resource cleanup. """ pass diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index e7e6ab7..f297cf3 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -23,7 +23,6 @@ CUDA platform backend implementation. """ -import atexit as _atexit import io as _io import sys as _sys from typing import Any as _Any, Callable as _Callable, Dict as _Dict @@ -42,7 +41,9 @@ class CUDAPlatform(_PlatformBackend): CUDA platform backend using PyCUDA. This backend wraps PyCUDA functionality to provide GPU-accelerated - GCMC sampling on NVIDIA GPUs. + GCMC sampling on NVIDIA GPUs. Uses the CUDA primary context for + compatibility with other CUDA libraries (e.g. OpenMM) sharing the + same device. """ def __init__( @@ -87,8 +88,6 @@ def __init__( When True, passes --use_fast_math to nvcc. Default: True (matches OpenMM defaults). 
""" - from pycuda.tools import make_default_context - # Initialize CUDA driver _cuda.init() @@ -100,9 +99,13 @@ def __init__( raise ValueError( f"'device' must be between 0 and {_cuda.Device.count() - 1}" ) - self._pycuda_context = _cuda.Device(device).make_context() + self._cuda_device = _cuda.Device(device) else: - self._pycuda_context = make_default_context() + self._cuda_device = _cuda.Device(0) + + # Use the primary context (shared with OpenMM and other CUDA users). + self._pycuda_context = self._cuda_device.retain_primary_context() + self._pycuda_context.push() self._device = self._pycuda_context.get_device() @@ -115,9 +118,6 @@ def __init__( self._nvcc = nvcc self._compiler_optimisations = compiler_optimisations - # Register cleanup - _atexit.register(self._cleanup_wrapper) - def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile CUDA kernels and return callable functions. @@ -165,12 +165,6 @@ def compile_kernels(self) -> _Dict[str, _Callable]: # Extract kernel functions kernels = { - "cell": mod.get_function("setCellMatrix"), - "rf": mod.get_function("setReactionField"), - "softcore": mod.get_function("setSoftCore"), - "atom_properties": mod.get_function("setAtomProperties"), - "atom_positions": mod.get_function("setAtomPositions"), - "water_properties": mod.get_function("setWaterProperties"), "update_water": mod.get_function("updateWater"), "deletion": mod.get_function("findDeletionCandidates"), "water": mod.get_function("generateWater"), @@ -231,38 +225,16 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: """ return buffer.get() - def push_context(self): - """ - Push the CUDA context onto the context stack. - """ - self._pycuda_context.push() - - def pop_context(self): - """ - Pop the CUDA context from the context stack. - """ - self._pycuda_context.pop() - def cleanup(self): """ - Clean up CUDA resources and detach context. 
- """ - try: - self.pop_context() - except Exception: - pass - self._pycuda_context.detach() - self._pycuda_context = None - - def _cleanup_wrapper(self): + Clean up CUDA resources and release primary context reference. """ - Wrapper for cleanup to handle atexit registration. - """ - try: - if self._pycuda_context is not None: - self.cleanup() - except Exception: - pass + if self._pycuda_context is not None: + try: + self._pycuda_context.pop() + except Exception: + pass + self._pycuda_context = None @property def platform_name(self) -> str: diff --git a/src/loch/_platforms/_opencl.py b/src/loch/_platforms/_opencl.py index a4fe550..98bea7f 100644 --- a/src/loch/_platforms/_opencl.py +++ b/src/loch/_platforms/_opencl.py @@ -192,12 +192,6 @@ def wrapper(*args, **kwargs): # Extract and wrap kernel functions kernels = { - "cell": make_kernel_wrapper(program.setCellMatrix), - "rf": make_kernel_wrapper(program.setReactionField), - "softcore": make_kernel_wrapper(program.setSoftCore), - "atom_properties": make_kernel_wrapper(program.setAtomProperties), - "atom_positions": make_kernel_wrapper(program.setAtomPositions), - "water_properties": make_kernel_wrapper(program.setWaterProperties), "update_water": make_kernel_wrapper(program.updateWater), "deletion": make_kernel_wrapper(program.findDeletionCandidates), "water": make_kernel_wrapper(program.generateWater), @@ -258,30 +252,10 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: """ return buffer.get() - def push_context(self): - """ - Push context (no-op for OpenCL). - - OpenCL doesn't use context stacking like CUDA, so this method - does nothing. It's provided for API compatibility. - """ - pass - - def pop_context(self): - """ - Pop context (no-op for OpenCL). - - OpenCL doesn't use context stacking like CUDA, so this method - does nothing. It's provided for API compatibility. - """ - pass - def cleanup(self): """ Clean up OpenCL resources. 
""" - # OpenCL resources are automatically released when objects are deleted - # No explicit cleanup needed, but we'll clear references self._queue = None self._context = None diff --git a/src/loch/_sampler.py b/src/loch/_sampler.py index faec480..539dc12 100644 --- a/src/loch/_sampler.py +++ b/src/loch/_sampler.py @@ -751,18 +751,6 @@ def _get_non_ghost_waters(self) -> _np.ndarray: """Get indices of non-ghost waters (cached).""" return self._non_ghost_waters_cache - def push(self) -> None: - """ - Push the GPU context on top of the stack (CUDA only, no-op for OpenCL). - """ - self._backend.push_context() - - def pop(self) -> None: - """ - Pop the GPU context from the stack (CUDA only, no-op for OpenCL). - """ - self._backend.pop_context() - def system(self) -> _Any: """ Return the GCMC system. @@ -826,14 +814,10 @@ def set_box(self, system: _Any) -> None: self._get_box_information(self._space) ) - # Update the cell matrix information on the GPU. - self._kernels["cell"]( - self._cell_matrix, - self._cell_matrix_inverse, - self._M, - block=(1, 1, 1), - grid=(1, 1, 1), - ) + # Store cell matrices as GPU buffers (used as kernel arguments). + self._gpu_cell_matrix = self._cell_matrix + self._gpu_cell_matrix_inverse = self._cell_matrix_inverse + self._gpu_M = self._M def set_bulk_sampling_probability(self, probability: float) -> None: """ @@ -882,19 +866,19 @@ def delete_waters(self, context: _openmm.Context) -> None: self._get_target_position(positions).astype(_np.float32) ) - # Set the positions on the GPU. - self._kernels["atom_positions"]( - self._backend.to_gpu(positions.astype(_np.float32).flatten()), - _np.float32(1.0), - block=(self._num_threads, 1, 1), - grid=(self._atom_blocks, 1, 1), - ) + # Upload atom positions to GPU. + self._gpu_position = self._backend.to_gpu(_as_float32(positions).flatten()) # Find the non-ghost waters within the GCMC region. 
self._kernels["deletion"]( self._deletion_candidates, self._backend.to_gpu(target.astype(_np.float32)), _np.float32(self._radius.value()), + self._gpu_position, + self._gpu_water_idx, + self._gpu_water_state, + self._gpu_cell_matrix_inverse, + self._gpu_M, block=(self._num_threads, 1, 1), grid=(self._water_blocks, 1, 1), ) @@ -944,19 +928,19 @@ def num_waters(self) -> int: self._get_target_position(positions).astype(_np.float32) ) - # Set the positions on the GPU. - self._kernels["atom_positions"]( - self._backend.to_gpu(positions.astype(_np.float32).flatten()), - _np.float32(1.0), - block=(self._num_threads, 1, 1), - grid=(self._atom_blocks, 1, 1), - ) + # Upload atom positions to GPU. + self._gpu_position = self._backend.to_gpu(_as_float32(positions).flatten()) # Find the non-ghost waters within the GCMC region. self._kernels["deletion"]( self._deletion_candidates, self._backend.to_gpu(target.astype(_np.float32)), _np.float32(self._radius.value()), + self._gpu_position, + self._gpu_water_idx, + self._gpu_water_state, + self._gpu_cell_matrix_inverse, + self._gpu_M, block=(self._num_threads, 1, 1), grid=(self._water_blocks, 1, 1), ) @@ -1175,12 +1159,9 @@ def move(self, context: _openmm.Context) -> list[int]: _np.float32 ) - # Set the positions on the GPU. - self._kernels["atom_positions"]( - self._backend.to_gpu(_as_float32(positions_angstrom).flatten()), - _np.float32(1.0), - block=(self._num_threads, 1, 1), - grid=(self._atom_blocks, 1, 1), + # Upload atom positions to GPU. + self._gpu_position = self._backend.to_gpu( + _as_float32(positions_angstrom).flatten() ) # Work out the number of waters in the sampling volume. 
@@ -1189,6 +1170,11 @@ def move(self, context: _openmm.Context) -> list[int]: self._deletion_candidates, self._backend.to_gpu(_as_float32(target)), _np.float32(self._radius.value()), + self._gpu_position, + self._gpu_water_idx, + self._gpu_water_state, + self._gpu_cell_matrix_inverse, + self._gpu_M, block=(self._num_threads, 1, 1), grid=(self._water_blocks, 1, 1), ) @@ -1287,6 +1273,7 @@ def move(self, context: _openmm.Context) -> list[int]: randoms_rotation, randoms_position, randoms_radius, + self._gpu_cell_matrix, block=(self._num_threads, 1, 1), grid=(self._batch_blocks, 1, 1), ) @@ -1299,6 +1286,25 @@ def move(self, context: _openmm.Context) -> list[int]: candidates_gpu, is_deletion_gpu, _np.int32(self._is_fep), + self._gpu_position, + self._gpu_charge, + self._gpu_sigma, + self._gpu_epsilon, + self._gpu_alpha, + self._gpu_is_ghost_water, + self._gpu_is_ghost_fep, + self._gpu_sigma_water, + self._gpu_epsilon_water, + self._gpu_charge_water, + self._gpu_water_idx, + self._gpu_cell_matrix_inverse, + self._gpu_M, + self._rf_cutoff, + self._rf_kappa, + self._rf_correction, + self._sc_coulomb_power, + self._sc_shift_coulomb, + self._sc_shift_delta, block=(self._num_threads, 1, 1), grid=(self._atom_blocks, self._batch_size, 1), ) @@ -2017,46 +2023,44 @@ def _initialise_gpu_memory(self): # Pre-allocate zero target array for bulk sampling. self._zero_target_gpu = self._backend.to_gpu(_np.zeros(3, dtype=_np.float32)) - # Initialise the reaction field parameters. - self._kernels["rf"]( - _np.float32(self._cutoff.value()), - _np.float32(78.3), - block=(1, 1, 1), - grid=(1, 1, 1), + # Compute reaction field parameters on host. + cutoff_val = self._cutoff.value() + self._rf_cutoff = _np.float32(cutoff_val) + self._rf_kappa = _np.float32( + (78.3 - 1.0) / ((2.0 * 78.3 + 1.0) * cutoff_val**3) ) - - # Initialise the soft-core parameters. 
- if self._is_fep: - self._kernels["softcore"]( - _np.float32(self._coulomb_power), - _np.float32(self._shift_coulomb.value()), - _np.float32(self._shift_delta.value()), - block=(1, 1, 1), - grid=(1, 1, 1), - ) - - # Set the atomic properties. - self._kernels["atom_properties"]( - charges, - sigmas, - epsilons, - alphas, - self._backend.to_gpu(is_ghost_water.astype(_np.int32)), - is_ghost_fep, - block=(self._num_threads, 1, 1), - grid=(self._atom_blocks, 1, 1), + self._rf_correction = _np.float32( + 1.0 / cutoff_val + float(self._rf_kappa) * cutoff_val**2 ) - # Set the water properties. - self._kernels["water_properties"]( - charge_water, - sigma_water, - epsilon_water, - self._backend.to_gpu(self._water_indices.astype(_np.int32)), - self._backend.to_gpu(self._water_state.astype(_np.int32)), - block=(1, 1, 1), - grid=(1, 1, 1), + # Store soft-core parameters as scalars. + self._sc_coulomb_power = _np.float32(self._coulomb_power) + self._sc_shift_coulomb = _np.float32(self._shift_coulomb.value()) + self._sc_shift_delta = _np.float32(self._shift_delta.value()) + + # Store immutable per-atom buffers on GPU. + self._gpu_sigma = sigmas + self._gpu_epsilon = epsilons + self._gpu_charge = charges + self._gpu_alpha = alphas + self._gpu_is_ghost_water = self._backend.to_gpu( + is_ghost_water.astype(_np.int32) ) + self._gpu_is_ghost_fep = is_ghost_fep + + # Store immutable water property buffers on GPU. + self._gpu_charge_water = charge_water + self._gpu_sigma_water = sigma_water + self._gpu_epsilon_water = epsilon_water + self._gpu_water_idx = self._backend.to_gpu( + self._water_indices.astype(_np.int32) + ) + self._gpu_water_state = self._backend.to_gpu( + self._water_state.astype(_np.int32) + ) + + # Allocate mutable position buffer (will be filled before each move). + self._gpu_position = self._backend.empty((1, self._num_atoms * 3), _np.float32) # Initialise the memory to store the water positions. 
self._water_positions = self._backend.empty( @@ -2162,6 +2166,14 @@ def _accept_insertion( _np.int32(1), _np.int32(1), self._backend.to_gpu(water_positions.flatten().astype(_np.float32)), + self._gpu_position, + self._gpu_charge, + self._gpu_epsilon, + self._gpu_is_ghost_water, + self._gpu_water_state, + self._gpu_water_idx, + self._gpu_charge_water, + self._gpu_epsilon_water, block=(1, 1, 1), grid=(1, 1, 1), ) @@ -2223,6 +2235,14 @@ def _accept_deletion(self, idx, context): self._backend.to_gpu( _np.zeros((self._num_points, 3), dtype=_np.float32).flatten() ), + self._gpu_position, + self._gpu_charge, + self._gpu_epsilon, + self._gpu_is_ghost_water, + self._gpu_water_state, + self._gpu_water_idx, + self._gpu_charge_water, + self._gpu_epsilon_water, block=(1, 1, 1), grid=(1, 1, 1), ) @@ -2287,6 +2307,14 @@ def _reject_deletion(self, idx, context): self._backend.to_gpu( _np.zeros((self._num_points, 3), dtype=_np.float32).flatten() ), + self._gpu_position, + self._gpu_charge, + self._gpu_epsilon, + self._gpu_is_ghost_water, + self._gpu_water_state, + self._gpu_water_idx, + self._gpu_charge_water, + self._gpu_epsilon_water, block=(1, 1, 1), grid=(1, 1, 1), ) @@ -2381,6 +2409,14 @@ def _set_water_state(self, context, indices=None, states=None, force=False): self._backend.to_gpu( _np.zeros((self._num_points, 3), dtype=_np.float32).flatten() ), + self._gpu_position, + self._gpu_charge, + self._gpu_epsilon, + self._gpu_is_ghost_water, + self._gpu_water_state, + self._gpu_water_idx, + self._gpu_charge_water, + self._gpu_epsilon_water, block=(1, 1, 1), grid=(1, 1, 1), ) @@ -2419,6 +2455,14 @@ def _set_water_state(self, context, indices=None, states=None, force=False): self._backend.to_gpu( _np.zeros((self._num_points, 3), dtype=_np.float32).flatten() ), + self._gpu_position, + self._gpu_charge, + self._gpu_epsilon, + self._gpu_is_ghost_water, + self._gpu_water_state, + self._gpu_water_idx, + self._gpu_charge_water, + self._gpu_epsilon_water, block=(1, 1, 1), grid=(1, 1, 1), 
) diff --git a/tests/test_energy.py b/tests/test_energy.py index 87cfcdc..4eaaf82 100644 --- a/tests/test_energy.py +++ b/tests/test_energy.py @@ -267,3 +267,93 @@ def test_platform_consistency(fixture, request): assert ( relative_diff < 0.001 ), f"Platform energies differ: CUDA={cuda_energy:.6f}, OpenCL={opencl_energy:.6f}, relative_diff={relative_diff:.6f}" + + +# Reference energy values captured with seed=42 on the original kernel implementation. +# These anchor the kernel output to exact values so that refactors (e.g. moving from +# __device__ static arrays to buffer arguments) can be validated. +_REFERENCE_ENERGIES = { + "water_box": { + "energy_coul": -9.45853172201302, + "energy_lj": 3.2191088, + }, + "bpti": { + "energy_coul": -15.377882774621897, + "energy_lj": -0.58867246, + }, +} + + +@pytest.mark.skipif( + "CUDA_VISIBLE_DEVICES" not in os.environ, + reason="Requires CUDA enabled GPU.", +) +@pytest.mark.parametrize("platform", ["cuda", "opencl"]) +@pytest.mark.parametrize("fixture", ["water_box", "bpti"]) +def test_energy_regression(fixture, platform, request): + """ + Test that kernel energy values are unchanged for a fixed random seed. + + This catches silent numerical changes introduced by kernel refactors. + """ + + # Get the fixture. + mols, reference = request.getfixturevalue(fixture) + + # Standard lambda schedule. + schedule = sr.cas.LambdaSchedule.standard_morph() + + # Set the lambda value. + lambda_value = 0.5 + + # Create a GCMC sampler with a fixed seed. + sampler = GCMCSampler( + mols, + cutoff_type="rf", + cutoff="10 A", + reference=reference, + lambda_schedule=schedule, + lambda_value=lambda_value, + log_level="debug", + ghost_file=None, + log_file=None, + test=True, + platform=platform, + seed=42, + ) + + # Create a dynamics object. 
+ d = sampler.system().dynamics( + cutoff_type="rf", + cutoff="10 A", + temperature="298 K", + pressure=None, + constraint="h_bonds", + timestep="2 fs", + schedule=schedule, + lambda_value=lambda_value, + coulomb_power=sampler._coulomb_power, + shift_coulomb=str(sampler._shift_coulomb), + shift_delta=str(sampler._shift_delta), + platform=platform, + ) + + # Loop until we get an accepted insertion move. + is_accepted = False + while not is_accepted: + moves = sampler.move(d.context()) + if len(moves) > 0 and moves[0] == 0: + is_accepted = True + + # Get the energy components. + energy_coul = sampler._debug["energy_coul"] + energy_lj = sampler._debug["energy_lj"] + + # Check against reference values. + ref = _REFERENCE_ENERGIES[fixture] + assert math.isclose( + energy_coul, ref["energy_coul"], abs_tol=1e-4 + ), f"Coulomb energy changed: {energy_coul!r} != {ref['energy_coul']!r}" + assert math.isclose( + energy_lj, ref["energy_lj"], abs_tol=1e-4 + ), f"LJ energy changed: {energy_lj!r} != {ref['energy_lj']!r}" From 516db88432caee7bcb2ec5c59079797b1a7b8916 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 2 Feb 2026 15:45:32 +0000 Subject: [PATCH 2/6] Add kernel compilation cache. --- src/loch/_platforms/_base.py | 12 +++ src/loch/_platforms/_cuda.py | 129 +++++++++++++++++++----------- src/loch/_platforms/_opencl.py | 138 +++++++++++++++++++++++---------- src/loch/_sampler.py | 13 ++++ tests/test_compiler.py | 3 + 5 files changed, 210 insertions(+), 85 deletions(-) diff --git a/src/loch/_platforms/_base.py b/src/loch/_platforms/_base.py index 1a291c8..71296b2 100644 --- a/src/loch/_platforms/_base.py +++ b/src/loch/_platforms/_base.py @@ -173,6 +173,18 @@ def platform_name(self) -> str: """ pass + @property + def cache_hit(self) -> bool: + """ + Whether the last compile_kernels() call was a cache hit. + + Returns + ------- + bool + True if kernels were loaded from cache, False if freshly compiled. 
+ """ + return getattr(self, "_cache_hit", False) + @property def compiler_log(self) -> str: """ diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index f297cf3..56fe42f 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -30,11 +30,16 @@ import numpy as _np import pycuda.driver as _cuda import pycuda.gpuarray as _gpuarray -from pycuda.compiler import SourceModule as _SourceModule +from pycuda.compiler import compile as _compile from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend +# Module-level kernel compilation cache. Keyed on +# (device_index, num_points, num_batch, num_waters, num_atoms, compiler_optimisations). +_kernel_cache = {} +_cache_stats = {"hits": 0, "misses": 0} + class CUDAPlatform(_PlatformBackend): """ @@ -99,9 +104,10 @@ def __init__( raise ValueError( f"'device' must be between 0 and {_cuda.Device.count() - 1}" ) - self._cuda_device = _cuda.Device(device) + self._device_index = device else: - self._cuda_device = _cuda.Device(0) + self._device_index = 0 + self._cuda_device = _cuda.Device(self._device_index) # Use the primary context (shared with OpenMM and other CUDA users). self._pycuda_context = self._cuda_device.retain_primary_context() @@ -122,46 +128,69 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile CUDA kernels and return callable functions. + Uses a module-level cache of compiled cubin bytes so that backends + with identical template parameters skip source compilation. + Returns ------- dict Dictionary mapping kernel names to callable kernel functions. """ - # Compile kernel module with template substitution. - # Suppress stderr but capture it for error reporting. 
- stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - - # Build compiler options - options = [] - if self._compiler_optimisations: - options.append("--use_fast_math") - - try: - _sys.stderr = stderr_capture - mod = _SourceModule( - _kernel_code - % { - "NUM_POINTS": self._num_points, - "NUM_BATCH": self._num_batch, - "NUM_WATERS": self._num_waters, - "NUM_ATOMS": self._num_atoms, - }, - no_extern_c=True, - nvcc=self._nvcc, - options=options, - ) - except Exception as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"CUDA kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - # Store any compiler warnings. - self._compiler_log = stderr_capture.getvalue().strip() + cache_key = ( + self._device_index, + self._num_points, + self._num_batch, + self._num_waters, + self._num_atoms, + self._compiler_optimisations, + ) + + if cache_key in _kernel_cache: + _cache_stats["hits"] += 1 + cubin = _kernel_cache[cache_key] + mod = _cuda.module_from_buffer(cubin) + self._compiler_log = "" + self._cache_hit = True + else: + _cache_stats["misses"] += 1 + + # Compile kernel source with template substitution. + # Suppress stderr but capture it for error reporting. 
+ stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + + options = [] + if self._compiler_optimisations: + options.append("--use_fast_math") + + source = _kernel_code % { + "NUM_POINTS": self._num_points, + "NUM_BATCH": self._num_batch, + "NUM_WATERS": self._num_waters, + "NUM_ATOMS": self._num_atoms, + } + + try: + _sys.stderr = stderr_capture + cubin = _compile( + source, + no_extern_c=True, + nvcc=self._nvcc, + options=options, + ) + except Exception as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"CUDA kernel compilation failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + self._compiler_log = stderr_capture.getvalue().strip() + self._cache_hit = False + _kernel_cache[cache_key] = cubin + mod = _cuda.module_from_buffer(cubin) # Extract kernel functions kernels = { @@ -174,6 +203,18 @@ def compile_kernels(self) -> _Dict[str, _Callable]: return kernels + @staticmethod + def get_cache_stats(): + """Return kernel cache statistics as a dict with 'hits' and 'misses'.""" + return _cache_stats.copy() + + @staticmethod + def clear_cache(): + """Clear the kernel cache and reset statistics.""" + _kernel_cache.clear() + _cache_stats["hits"] = 0 + _cache_stats["misses"] = 0 + def to_gpu(self, array: _np.ndarray) -> _Any: """ Transfer a NumPy array to GPU memory. @@ -227,14 +268,12 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: def cleanup(self): """ - Clean up CUDA resources and release primary context reference. + Clean up CUDA resources. + + The primary context is intentionally not popped — it is shared with + OpenMM and other CUDA users and must remain current. 
""" - if self._pycuda_context is not None: - try: - self._pycuda_context.pop() - except Exception: - pass - self._pycuda_context = None + self._pycuda_context = None @property def platform_name(self) -> str: diff --git a/src/loch/_platforms/_opencl.py b/src/loch/_platforms/_opencl.py index 98bea7f..5f55b2a 100644 --- a/src/loch/_platforms/_opencl.py +++ b/src/loch/_platforms/_opencl.py @@ -35,6 +35,11 @@ from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend +# Module-level kernel compilation cache. Stores compiled binaries keyed on +# (device_index, num_points, num_batch, num_waters, num_atoms, compiler_optimisations). +_kernel_cache = {} +_cache_stats = {"hits": 0, "misses": 0} + class OpenCLPlatform(_PlatformBackend): """ @@ -101,9 +106,10 @@ def __init__( raise ValueError("'device' must be of type 'int'") if device < 0 or device >= len(devices): raise ValueError(f"'device' must be between 0 and {len(devices) - 1}") - self._device = devices[device] + self._device_index = device else: - self._device = devices[0] + self._device_index = 0 + self._device = devices[self._device_index] # Create context and command queue self._context = _cl.Context([self._device]) @@ -121,62 +127,104 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile OpenCL kernels and return callable functions. + Uses a module-level cache of compiled binaries so that backends with + identical template parameters skip source compilation. + Returns ------- dict Dictionary mapping kernel names to callable kernel functions. 
""" - # Substitute template parameters - kernel_source = _kernel_code % { - "NUM_POINTS": self._num_points, - "NUM_BATCH": self._num_batch, - "NUM_WATERS": self._num_waters, - "NUM_ATOMS": self._num_atoms, - } + cache_key = ( + self._device_index, + self._num_points, + self._num_batch, + self._num_waters, + self._num_atoms, + self._compiler_optimisations, + ) # Build compiler options build_options = [] if self._compiler_optimisations: build_options.extend(["-cl-mad-enable", "-cl-no-signed-zeros"]) - # Compile program, suppressing stderr and warnings but capturing for errors. - stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - try: - _sys.stderr = stderr_capture - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore") - program = _cl.Program(self._context, kernel_source).build( - options=build_options - ) - except _cl.RuntimeError as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"OpenCL kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - # Capture the compiler log (including any warnings). - self._compiler_log = program.get_build_info( - self._device, _cl.program_build_info.LOG - ).strip() - - # Create kernel wrappers that match PyCUDA calling convention + if cache_key in _kernel_cache: + _cache_stats["hits"] += 1 + cached_binary = _kernel_cache[cache_key] + + # Create program from cached binary. 
+ stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + try: + _sys.stderr = stderr_capture + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + program = _cl.Program( + self._context, [self._device], [cached_binary] + ) + program.build(options=build_options) + except _cl.RuntimeError as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"OpenCL kernel build from cached binary failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + self._compiler_log = "" + self._cache_hit = True + else: + _cache_stats["misses"] += 1 + + # Substitute template parameters + kernel_source = _kernel_code % { + "NUM_POINTS": self._num_points, + "NUM_BATCH": self._num_batch, + "NUM_WATERS": self._num_waters, + "NUM_ATOMS": self._num_atoms, + } + + # Compile program from source, suppressing stderr and warnings. + stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + try: + _sys.stderr = stderr_capture + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + program = _cl.Program(self._context, kernel_source).build( + options=build_options + ) + except _cl.RuntimeError as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"OpenCL kernel compilation failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + # Capture the compiler log (including any warnings). + self._compiler_log = program.get_build_info( + self._device, _cl.program_build_info.LOG + ).strip() + self._cache_hit = False + + # Cache the compiled binary. + _kernel_cache[cache_key] = program.get_info(_cl.program_info.BINARIES)[0] + + # Create kernel wrappers that match PyCUDA calling convention. 
# OpenCL kernels need (queue, global_size, local_size, *args) - # We'll wrap them to match CUDA's (args..., block, grid) signature + # We wrap them to match CUDA's (args..., block, grid) signature. def make_kernel_wrapper(kernel): def wrapper(*args, **kwargs): - # Extract block and grid from kwargs block = kwargs.get("block", (self._num_threads, 1, 1)) grid = kwargs.get("grid", (1, 1, 1)) - # Calculate global work size global_size = tuple(b * g for b, g in zip(block, grid)) local_size = block - # Convert PyOpenCL array objects to their .data buffers processed_args = [] for arg in args: if isinstance(arg, _cl_array.Array): @@ -184,13 +232,11 @@ def wrapper(*args, **kwargs): else: processed_args.append(arg) - # Execute kernel kernel(self._queue, global_size, local_size, *processed_args) self._queue.finish() return wrapper - # Extract and wrap kernel functions kernels = { "update_water": make_kernel_wrapper(program.updateWater), "deletion": make_kernel_wrapper(program.findDeletionCandidates), @@ -201,6 +247,18 @@ def wrapper(*args, **kwargs): return kernels + @staticmethod + def get_cache_stats(): + """Return kernel cache statistics as a dict with 'hits' and 'misses'.""" + return _cache_stats.copy() + + @staticmethod + def clear_cache(): + """Clear the kernel cache and reset statistics.""" + _kernel_cache.clear() + _cache_stats["hits"] = 0 + _cache_stats["misses"] = 0 + def to_gpu(self, array: _np.ndarray) -> _Any: """ Transfer a NumPy array to GPU memory. diff --git a/src/loch/_sampler.py b/src/loch/_sampler.py index 539dc12..76d10dd 100644 --- a/src/loch/_sampler.py +++ b/src/loch/_sampler.py @@ -751,6 +751,19 @@ def _get_non_ghost_waters(self) -> _np.ndarray: """Get indices of non-ghost waters (cached).""" return self._non_ghost_waters_cache + @property + def kernel_cache_hit(self) -> bool: + """ + Whether kernel compilation was satisfied from cache. + + Returns + ------- + + cache_hit: bool + True if kernels were loaded from cache, False if freshly compiled. 
+ """ + return self._backend.cache_hit + def system(self) -> _Any: """ Return the GCMC system. diff --git a/tests/test_compiler.py b/tests/test_compiler.py index b369da3..8de255a 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -113,6 +113,9 @@ def test_compilation_error_raises_exception(self): nvcc=_get_nvcc(), ) + # Clear the kernel cache so the patched code is actually compiled. + CUDAPlatform.clear_cache() + # Patch kernel code directly in the cuda module (not the kernels module, # since it's already imported as _kernel_code at module load time). original_code = cuda_module._kernel_code From cfa01c211f4b078d21a93cb1b2c75cf97065ea66 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 2 Feb 2026 16:13:29 +0000 Subject: [PATCH 3/6] Still need to push/pop CUDA contexts. --- src/loch/_platforms/_base.py | 17 +++++++++++++++++ src/loch/_platforms/_cuda.py | 24 +++++++++++++++++++----- src/loch/_sampler.py | 8 ++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/loch/_platforms/_base.py b/src/loch/_platforms/_base.py index 71296b2..1584a4f 100644 --- a/src/loch/_platforms/_base.py +++ b/src/loch/_platforms/_base.py @@ -150,6 +150,23 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: """ pass + def push_context(self): + """ + Make the GPU context current on the calling thread. + + For CUDA, this pushes the primary context onto the thread-local + context stack. For OpenCL, this is a no-op. + """ + pass + + def pop_context(self): + """ + Remove the GPU context from the calling thread's stack. + + For CUDA, this pops the primary context. For OpenCL, this is a no-op. 
+ """ + pass + @_abstractmethod def cleanup(self): """ diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index 56fe42f..418d715 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -266,14 +266,28 @@ def from_gpu(self, buffer: _Any) -> _np.ndarray: """ return buffer.get() - def cleanup(self): + def push_context(self): + """ + Push the primary context onto the calling thread's context stack. """ - Clean up CUDA resources. + self._pycuda_context.push() - The primary context is intentionally not popped — it is shared with - OpenMM and other CUDA users and must remain current. + def pop_context(self): """ - self._pycuda_context = None + Pop the primary context from the calling thread's context stack. + """ + self._pycuda_context.pop() + + def cleanup(self): + """ + Clean up CUDA resources and pop the context pushed during __init__. + """ + if self._pycuda_context is not None: + try: + self._pycuda_context.pop() + except Exception: + pass + self._pycuda_context = None @property def platform_name(self) -> str: diff --git a/src/loch/_sampler.py b/src/loch/_sampler.py index 76d10dd..79744e9 100644 --- a/src/loch/_sampler.py +++ b/src/loch/_sampler.py @@ -751,6 +751,14 @@ def _get_non_ghost_waters(self) -> _np.ndarray: """Get indices of non-ghost waters (cached).""" return self._non_ghost_waters_cache + def push(self) -> None: + """Push the GPU context onto the calling thread's context stack.""" + self._backend.push_context() + + def pop(self) -> None: + """Pop the GPU context from the calling thread's context stack.""" + self._backend.pop_context() + @property def kernel_cache_hit(self) -> bool: """ From 7f5e4de3fecb12bd58385fdc43487da4316d7aa8 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 3 Feb 2026 11:10:26 +0000 Subject: [PATCH 4/6] Refactor kernels to remove compile-time constants. 
--- src/loch/_kernels.py | 26 ++++++-- src/loch/_platforms/_base.py | 12 ---- src/loch/_platforms/_cuda.py | 103 ++++++++--------------------- src/loch/_platforms/_opencl.py | 116 +++++++-------------------------- src/loch/_sampler.py | 28 ++++---- tests/test_compiler.py | 3 - 6 files changed, 85 insertions(+), 203 deletions(-) diff --git a/src/loch/_kernels.py b/src/loch/_kernels.py index 710b107..352775f 100644 --- a/src/loch/_kernels.py +++ b/src/loch/_kernels.py @@ -55,13 +55,12 @@ // Constants. const float pi = 3.14159265359f; - const int num_points = %(NUM_POINTS)s; - const int num_batch = %(NUM_BATCH)s; - const int num_atoms = %(NUM_ATOMS)s; - const int num_waters = %(NUM_WATERS)s; - const int num_water_positions = 3 * num_points; const float prefactor = 332.0637090025476f; + // Maximum number of atoms per water molecule (for stack array sizing). + #define MAX_POINTS 5 + #define MAX_WATER_POSITIONS (3 * MAX_POINTS) + #ifndef __OPENCL_VERSION__ extern "C" { @@ -216,6 +215,7 @@ // Update a single water. KERNEL void updateWater( + int num_points, int idx, int state, int is_insertion, @@ -267,6 +267,8 @@ // Generate a random position and orientation within the GCMC sphere // for each trial insertion. KERNEL void generateWater( + int num_points, + int num_batch, GLOBAL float* water_template, GLOBAL float* target, float radius, @@ -283,8 +285,10 @@ // Make sure we are within the number of waters. if (tidx < num_batch) { + const int num_water_positions = 3 * num_points; + // Translate the oxygen atom to the origin. - float water[num_water_positions]; + float water[MAX_WATER_POSITIONS]; water[0] = 0.0f; water[1] = 0.0f; water[2] = 0.0f; @@ -304,7 +308,7 @@ randoms_rotation[tidx * 3 + 2]); // Calculate the distance between the oxygen and the hydrogens. 
- float dh[num_points][3]; + float dh[MAX_POINTS][3]; for (int i = 0; i < num_points-1; i++) { dh[i][0] = water[(i+1)*3] - water[0]; @@ -368,6 +372,9 @@ // Compute the Lennard-Jones and reaction field Coulomb energy between // the water and the atoms. KERNEL void computeEnergy( + int num_points, + int num_batch, + int num_atoms, GLOBAL float* water_position, GLOBAL float* energy_coul, GLOBAL float* energy_lj, @@ -400,6 +407,8 @@ // Make sure we're in bounds. if (idx_atom < num_atoms) { + const int num_water_positions = 3 * num_points; + // Store the squared cut-off distance. const float cutoff2 = rf_cutoff * rf_cutoff; @@ -582,6 +591,8 @@ // Calculate whether each attempt is accepted. KERNEL void checkAcceptance( + int num_batch, + int num_atoms, int N, float exp_B, float exp_minus_B, @@ -654,6 +665,7 @@ // Find candidate waters for deletion. KERNEL void findDeletionCandidates( + int num_waters, GLOBAL int* candidates, GLOBAL float* target, float radius, diff --git a/src/loch/_platforms/_base.py b/src/loch/_platforms/_base.py index 1584a4f..1a720ad 100644 --- a/src/loch/_platforms/_base.py +++ b/src/loch/_platforms/_base.py @@ -190,18 +190,6 @@ def platform_name(self) -> str: """ pass - @property - def cache_hit(self) -> bool: - """ - Whether the last compile_kernels() call was a cache hit. - - Returns - ------- - bool - True if kernels were loaded from cache, False if freshly compiled. - """ - return getattr(self, "_cache_hit", False) - @property def compiler_log(self) -> str: """ diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index 418d715..7ad13cf 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -35,11 +35,6 @@ from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend -# Module-level kernel compilation cache. Keyed on -# (device_index, num_points, num_batch, num_waters, num_atoms, compiler_optimisations). 
-_kernel_cache = {} -_cache_stats = {"hits": 0, "misses": 0} - class CUDAPlatform(_PlatformBackend): """ @@ -128,69 +123,39 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile CUDA kernels and return callable functions. - Uses a module-level cache of compiled cubin bytes so that backends - with identical template parameters skip source compilation. - Returns ------- dict Dictionary mapping kernel names to callable kernel functions. """ - cache_key = ( - self._device_index, - self._num_points, - self._num_batch, - self._num_waters, - self._num_atoms, - self._compiler_optimisations, - ) - - if cache_key in _kernel_cache: - _cache_stats["hits"] += 1 - cubin = _kernel_cache[cache_key] - mod = _cuda.module_from_buffer(cubin) - self._compiler_log = "" - self._cache_hit = True - else: - _cache_stats["misses"] += 1 - - # Compile kernel source with template substitution. - # Suppress stderr but capture it for error reporting. - stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - - options = [] - if self._compiler_optimisations: - options.append("--use_fast_math") - - source = _kernel_code % { - "NUM_POINTS": self._num_points, - "NUM_BATCH": self._num_batch, - "NUM_WATERS": self._num_waters, - "NUM_ATOMS": self._num_atoms, - } - - try: - _sys.stderr = stderr_capture - cubin = _compile( - source, - no_extern_c=True, - nvcc=self._nvcc, - options=options, - ) - except Exception as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"CUDA kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - self._compiler_log = stderr_capture.getvalue().strip() - self._cache_hit = False - _kernel_cache[cache_key] = cubin - mod = _cuda.module_from_buffer(cubin) + # Compile kernel source. + # Suppress stderr but capture it for error reporting. 
+ stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + + options = [] + if self._compiler_optimisations: + options.append("--use_fast_math") + + try: + _sys.stderr = stderr_capture + cubin = _compile( + _kernel_code, + no_extern_c=True, + nvcc=self._nvcc, + options=options, + ) + except Exception as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"CUDA kernel compilation failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + self._compiler_log = stderr_capture.getvalue().strip() + mod = _cuda.module_from_buffer(cubin) # Extract kernel functions kernels = { @@ -203,18 +168,6 @@ def compile_kernels(self) -> _Dict[str, _Callable]: return kernels - @staticmethod - def get_cache_stats(): - """Return kernel cache statistics as a dict with 'hits' and 'misses'.""" - return _cache_stats.copy() - - @staticmethod - def clear_cache(): - """Clear the kernel cache and reset statistics.""" - _kernel_cache.clear() - _cache_stats["hits"] = 0 - _cache_stats["misses"] = 0 - def to_gpu(self, array: _np.ndarray) -> _Any: """ Transfer a NumPy array to GPU memory. diff --git a/src/loch/_platforms/_opencl.py b/src/loch/_platforms/_opencl.py index 5f55b2a..b699344 100644 --- a/src/loch/_platforms/_opencl.py +++ b/src/loch/_platforms/_opencl.py @@ -35,11 +35,6 @@ from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend -# Module-level kernel compilation cache. Stores compiled binaries keyed on -# (device_index, num_points, num_batch, num_waters, num_atoms, compiler_optimisations). -_kernel_cache = {} -_cache_stats = {"hits": 0, "misses": 0} - class OpenCLPlatform(_PlatformBackend): """ @@ -127,92 +122,39 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile OpenCL kernels and return callable functions. 
- Uses a module-level cache of compiled binaries so that backends with - identical template parameters skip source compilation. - Returns ------- dict Dictionary mapping kernel names to callable kernel functions. """ - cache_key = ( - self._device_index, - self._num_points, - self._num_batch, - self._num_waters, - self._num_atoms, - self._compiler_optimisations, - ) - # Build compiler options build_options = [] if self._compiler_optimisations: build_options.extend(["-cl-mad-enable", "-cl-no-signed-zeros"]) - if cache_key in _kernel_cache: - _cache_stats["hits"] += 1 - cached_binary = _kernel_cache[cache_key] - - # Create program from cached binary. - stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - try: - _sys.stderr = stderr_capture - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore") - program = _cl.Program( - self._context, [self._device], [cached_binary] - ) - program.build(options=build_options) - except _cl.RuntimeError as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"OpenCL kernel build from cached binary failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - self._compiler_log = "" - self._cache_hit = True - else: - _cache_stats["misses"] += 1 - - # Substitute template parameters - kernel_source = _kernel_code % { - "NUM_POINTS": self._num_points, - "NUM_BATCH": self._num_batch, - "NUM_WATERS": self._num_waters, - "NUM_ATOMS": self._num_atoms, - } - - # Compile program from source, suppressing stderr and warnings. 
- stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - try: - _sys.stderr = stderr_capture - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore") - program = _cl.Program(self._context, kernel_source).build( - options=build_options - ) - except _cl.RuntimeError as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"OpenCL kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - # Capture the compiler log (including any warnings). - self._compiler_log = program.get_build_info( - self._device, _cl.program_build_info.LOG - ).strip() - self._cache_hit = False - - # Cache the compiled binary. - _kernel_cache[cache_key] = program.get_info(_cl.program_info.BINARIES)[0] + # Compile program from source, suppressing stderr and warnings. + stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + try: + _sys.stderr = stderr_capture + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + program = _cl.Program(self._context, _kernel_code).build( + options=build_options + ) + except _cl.RuntimeError as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"OpenCL kernel compilation failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + # Capture the compiler log (including any warnings). + self._compiler_log = program.get_build_info( + self._device, _cl.program_build_info.LOG + ).strip() # Create kernel wrappers that match PyCUDA calling convention. 
# OpenCL kernels need (queue, global_size, local_size, *args) @@ -247,18 +189,6 @@ def wrapper(*args, **kwargs): return kernels - @staticmethod - def get_cache_stats(): - """Return kernel cache statistics as a dict with 'hits' and 'misses'.""" - return _cache_stats.copy() - - @staticmethod - def clear_cache(): - """Clear the kernel cache and reset statistics.""" - _kernel_cache.clear() - _cache_stats["hits"] = 0 - _cache_stats["misses"] = 0 - def to_gpu(self, array: _np.ndarray) -> _Any: """ Transfer a NumPy array to GPU memory. diff --git a/src/loch/_sampler.py b/src/loch/_sampler.py index 79744e9..dc90336 100644 --- a/src/loch/_sampler.py +++ b/src/loch/_sampler.py @@ -759,19 +759,6 @@ def pop(self) -> None: """Pop the GPU context from the calling thread's context stack.""" self._backend.pop_context() - @property - def kernel_cache_hit(self) -> bool: - """ - Whether kernel compilation was satisfied from cache. - - Returns - ------- - - cache_hit: bool - True if kernels were loaded from cache, False if freshly compiled. - """ - return self._backend.cache_hit - def system(self) -> _Any: """ Return the GCMC system. @@ -892,6 +879,7 @@ def delete_waters(self, context: _openmm.Context) -> None: # Find the non-ghost waters within the GCMC region. self._kernels["deletion"]( + _np.int32(self._num_waters), self._deletion_candidates, self._backend.to_gpu(target.astype(_np.float32)), _np.float32(self._radius.value()), @@ -954,6 +942,7 @@ def num_waters(self) -> int: # Find the non-ghost waters within the GCMC region. self._kernels["deletion"]( + _np.int32(self._num_waters), self._deletion_candidates, self._backend.to_gpu(target.astype(_np.float32)), _np.float32(self._radius.value()), @@ -1188,6 +1177,7 @@ def move(self, context: _openmm.Context) -> list[int]: # Work out the number of waters in the sampling volume. 
if not self._is_bulk: self._kernels["deletion"]( + _np.int32(self._num_waters), self._deletion_candidates, self._backend.to_gpu(_as_float32(target)), _np.float32(self._radius.value()), @@ -1286,6 +1276,8 @@ def move(self, context: _openmm.Context) -> list[int]: # Generate the random water positions and orientations. self._kernels["water"]( + _np.int32(self._num_points), + _np.int32(self._batch_size), template_positions, target_gpu, _np.float32(self._radius.value()), @@ -1301,6 +1293,9 @@ def move(self, context: _openmm.Context) -> list[int]: # Perform the energy calculation. self._kernels["energy"]( + _np.int32(self._num_points), + _np.int32(self._batch_size), + _np.int32(self._num_atoms), self._water_positions, self._energy_coul, self._energy_lj, @@ -1335,6 +1330,8 @@ def move(self, context: _openmm.Context) -> list[int]: # Check the acceptance for each trial state. self._kernels["acceptance"]( + _np.int32(self._batch_size), + _np.int32(self._num_atoms), _np.int32(self._N), _np.float32(exp_B), _np.float32(exp_minus_B), @@ -2183,6 +2180,7 @@ def _accept_insertion( # Update the state of the water on the GPU. self._kernels["update_water"]( + _np.int32(self._num_points), _np.int32(idx_water), _np.int32(1), _np.int32(1), @@ -2250,6 +2248,7 @@ def _accept_deletion(self, idx, context): # Update the state of the water on the GPU. self._kernels["update_water"]( + _np.int32(self._num_points), _np.int32(idx), _np.int32(0), _np.int32(0), @@ -2322,6 +2321,7 @@ def _reject_deletion(self, idx, context): # Update the state of the water on the GPU. self._kernels["update_water"]( + _np.int32(self._num_points), _np.int32(idx), _np.int32(1), _np.int32(0), @@ -2424,6 +2424,7 @@ def _set_water_state(self, context, indices=None, states=None, force=False): # Update the state of the water on the GPU. 
self._kernels["update_water"]( + _np.int32(self._num_points), _np.int32(idx), _np.int32(0), _np.int32(0), @@ -2470,6 +2471,7 @@ def _set_water_state(self, context, indices=None, states=None, force=False): # Update the state of the water on the GPU. self._kernels["update_water"]( + _np.int32(self._num_points), _np.int32(idx), _np.int32(1), _np.int32(0), diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 8de255a..b369da3 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -113,9 +113,6 @@ def test_compilation_error_raises_exception(self): nvcc=_get_nvcc(), ) - # Clear the kernel cache so the patched code is actually compiled. - CUDAPlatform.clear_cache() - # Patch kernel code directly in the cuda module (not the kernels module, # since it's already imported as _kernel_code at module load time). original_code = cuda_module._kernel_code From 40ee8ef2ea5c0055a4bacc02321890f7372275de Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 3 Feb 2026 11:14:10 +0000 Subject: [PATCH 5/6] Fix hard-coded water model in uniform_random_rotation kernel. --- src/loch/_kernels.py | 77 +++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/src/loch/_kernels.py b/src/loch/_kernels.py index 352775f..ba20ebb 100644 --- a/src/loch/_kernels.py +++ b/src/loch/_kernels.py @@ -126,7 +126,7 @@ } // Perform a random rotation about a unit sphere. - DEVICE void uniform_random_rotation(float* v, float r0, float r1, float r2) + DEVICE void uniform_random_rotation(float* v, int num_points, float r0, float r1, float r2) { /* Adapted from: https://www.blopig.com/blog/2021/08/uniformly-sampled-3d-rotation-matrices/ @@ -167,50 +167,47 @@ // Now compute M = -(H @ R), i.e. rotate all points around the x axis. 
float M[3][3]; - M[0][0] = -(H[0][0] * R[0][0] + H[0][1] * R[1][0] + H[0][2] * R[2][0]); - M[0][1] = -(H[0][0] * R[0][1] + H[0][1] * R[1][1] + H[0][2] * R[2][1]); - M[0][2] = -(H[0][0] * R[0][2] + H[0][1] * R[1][2] + H[0][2] * R[2][2]); - M[1][0] = -(H[1][0] * R[0][0] + H[1][1] * R[1][0] + H[1][2] * R[2][0]); - M[1][1] = -(H[1][0] * R[0][1] + H[1][1] * R[1][1] + H[1][2] * R[2][1]); - M[1][2] = -(H[1][0] * R[0][2] + H[1][1] * R[1][2] + H[1][2] * R[2][2]); - M[2][0] = -(H[2][0] * R[0][0] + H[2][1] * R[1][0] + H[2][2] * R[2][0]); - M[2][1] = -(H[2][0] * R[0][1] + H[2][1] * R[1][1] + H[2][2] * R[2][1]); - M[2][2] = -(H[2][0] * R[0][2] + H[2][1] * R[1][2] + H[2][2] * R[2][2]); + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + M[i][j] = -(H[i][0] * R[0][j] + H[i][1] * R[1][j] + H[i][2] * R[2][j]); + } + } // Compute the mean coordinate of the water molecule. float mean_coord[3]; - mean_coord[0] = (v[0] + v[3] + v[6]) / 3.0f; - mean_coord[1] = (v[1] + v[4] + v[7]) / 3.0f; - mean_coord[2] = (v[2] + v[5] + v[8]) / 3.0f; + mean_coord[0] = 0.0f; + mean_coord[1] = 0.0f; + mean_coord[2] = 0.0f; + for (int i = 0; i < num_points; i++) + { + mean_coord[0] += v[i * 3]; + mean_coord[1] += v[i * 3 + 1]; + mean_coord[2] += v[i * 3 + 2]; + } + mean_coord[0] /= (float)num_points; + mean_coord[1] /= (float)num_points; + mean_coord[2] /= (float)num_points; // Precompute mean_coord @ M (avoids redundant calculations). float mean_M[3]; - mean_M[0] = fmaf(mean_coord[0], M[0][0], fmaf(mean_coord[1], M[1][0], mean_coord[2] * M[2][0])); - mean_M[1] = fmaf(mean_coord[0], M[0][1], fmaf(mean_coord[1], M[1][1], mean_coord[2] * M[2][1])); - mean_M[2] = fmaf(mean_coord[0], M[0][2], fmaf(mean_coord[1], M[1][2], mean_coord[2] * M[2][2])); - - // Now compute ((v - mean_coord) @ M) + mean_M. 
- float x[3][3]; - x[0][0] = v[0] - mean_coord[0]; - x[0][1] = v[1] - mean_coord[1]; - x[0][2] = v[2] - mean_coord[2]; - x[1][0] = v[3] - mean_coord[0]; - x[1][1] = v[4] - mean_coord[1]; - x[1][2] = v[5] - mean_coord[2]; - x[2][0] = v[6] - mean_coord[0]; - x[2][1] = v[7] - mean_coord[1]; - x[2][2] = v[8] - mean_coord[2]; - - // Compute the rotated coordinates using fma. - v[0] = fmaf(x[0][0], M[0][0], fmaf(x[0][1], M[1][0], fmaf(x[0][2], M[2][0], mean_M[0]))); - v[1] = fmaf(x[0][0], M[0][1], fmaf(x[0][1], M[1][1], fmaf(x[0][2], M[2][1], mean_M[1]))); - v[2] = fmaf(x[0][0], M[0][2], fmaf(x[0][1], M[1][2], fmaf(x[0][2], M[2][2], mean_M[2]))); - v[3] = fmaf(x[1][0], M[0][0], fmaf(x[1][1], M[1][0], fmaf(x[1][2], M[2][0], mean_M[0]))); - v[4] = fmaf(x[1][0], M[0][1], fmaf(x[1][1], M[1][1], fmaf(x[1][2], M[2][1], mean_M[1]))); - v[5] = fmaf(x[1][0], M[0][2], fmaf(x[1][1], M[1][2], fmaf(x[1][2], M[2][2], mean_M[2]))); - v[6] = fmaf(x[2][0], M[0][0], fmaf(x[2][1], M[1][0], fmaf(x[2][2], M[2][0], mean_M[0]))); - v[7] = fmaf(x[2][0], M[0][1], fmaf(x[2][1], M[1][1], fmaf(x[2][2], M[2][1], mean_M[1]))); - v[8] = fmaf(x[2][0], M[0][2], fmaf(x[2][1], M[1][2], fmaf(x[2][2], M[2][2], mean_M[2]))); + for (int j = 0; j < 3; j++) + { + mean_M[j] = fmaf(mean_coord[0], M[0][j], fmaf(mean_coord[1], M[1][j], mean_coord[2] * M[2][j])); + } + + // Compute ((v - mean_coord) @ M) + mean_M for each atom. + for (int i = 0; i < num_points; i++) + { + float dx = v[i * 3] - mean_coord[0]; + float dy = v[i * 3 + 1] - mean_coord[1]; + float dz = v[i * 3 + 2] - mean_coord[2]; + + v[i * 3] = fmaf(dx, M[0][0], fmaf(dy, M[1][0], fmaf(dz, M[2][0], mean_M[0]))); + v[i * 3 + 1] = fmaf(dx, M[0][1], fmaf(dy, M[1][1], fmaf(dz, M[2][1], mean_M[1]))); + v[i * 3 + 2] = fmaf(dx, M[0][2], fmaf(dy, M[1][2], fmaf(dz, M[2][2], mean_M[2]))); + } } // Update a single water. @@ -302,7 +299,7 @@ } // Rotate the water randomly using pre-generated randoms. 
- uniform_random_rotation(water, + uniform_random_rotation(water, num_points, randoms_rotation[tidx * 3], randoms_rotation[tidx * 3 + 1], randoms_rotation[tidx * 3 + 2]); From d4975e962b7f69006b5b0137e25d8c2d23c60db8 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 3 Feb 2026 12:43:08 +0000 Subject: [PATCH 6/6] Re-add simplified cache for compiled kernel code. --- src/loch/_platforms/_base.py | 12 +++++ src/loch/_platforms/_cuda.py | 78 +++++++++++++++++++---------- src/loch/_platforms/_opencl.py | 91 +++++++++++++++++++++++++--------- src/loch/_sampler.py | 13 +++++ tests/test_compiler.py | 3 ++ tests/test_energy.py | 87 ++++++++++++++++++++++++++++++++ 6 files changed, 234 insertions(+), 50 deletions(-) diff --git a/src/loch/_platforms/_base.py b/src/loch/_platforms/_base.py index 1a720ad..1584a4f 100644 --- a/src/loch/_platforms/_base.py +++ b/src/loch/_platforms/_base.py @@ -190,6 +190,18 @@ def platform_name(self) -> str: """ pass + @property + def cache_hit(self) -> bool: + """ + Whether the last compile_kernels() call was a cache hit. + + Returns + ------- + bool + True if kernels were loaded from cache, False if freshly compiled. + """ + return getattr(self, "_cache_hit", False) + @property def compiler_log(self) -> str: """ diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index 7ad13cf..4047115 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -35,6 +35,12 @@ from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend +# Module-level kernel compilation cache. Keyed on +# (device_index, compiler_optimisations). Since the kernel source no longer +# depends on system-specific parameters, the same compiled binary can be +# reused across all samplers on a given device. +_kernel_cache = {} + class CUDAPlatform(_PlatformBackend): """ @@ -123,38 +129,51 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile CUDA kernels and return callable functions. 
+ Uses a module-level cache so that only the first sampler on a given + device pays the nvcc compilation cost. + Returns ------- dict Dictionary mapping kernel names to callable kernel functions. """ - # Compile kernel source. - # Suppress stderr but capture it for error reporting. - stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - - options = [] - if self._compiler_optimisations: - options.append("--use_fast_math") - - try: - _sys.stderr = stderr_capture - cubin = _compile( - _kernel_code, - no_extern_c=True, - nvcc=self._nvcc, - options=options, - ) - except Exception as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"CUDA kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - self._compiler_log = stderr_capture.getvalue().strip() + cache_key = (self._device_index, self._compiler_optimisations) + + if cache_key in _kernel_cache: + cubin = _kernel_cache[cache_key] + self._compiler_log = "" + self._cache_hit = True + else: + # Compile kernel source. + # Suppress stderr but capture it for error reporting. 
+ stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + + options = [] + if self._compiler_optimisations: + options.append("--use_fast_math") + + try: + _sys.stderr = stderr_capture + cubin = _compile( + _kernel_code, + no_extern_c=True, + nvcc=self._nvcc, + options=options, + ) + except Exception as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"CUDA kernel compilation failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + self._compiler_log = stderr_capture.getvalue().strip() + self._cache_hit = False + _kernel_cache[cache_key] = cubin + mod = _cuda.module_from_buffer(cubin) # Extract kernel functions @@ -168,6 +187,11 @@ def compile_kernels(self) -> _Dict[str, _Callable]: return kernels + @staticmethod + def clear_cache(): + """Clear the kernel compilation cache.""" + _kernel_cache.clear() + def to_gpu(self, array: _np.ndarray) -> _Any: """ Transfer a NumPy array to GPU memory. diff --git a/src/loch/_platforms/_opencl.py b/src/loch/_platforms/_opencl.py index b699344..6395c43 100644 --- a/src/loch/_platforms/_opencl.py +++ b/src/loch/_platforms/_opencl.py @@ -35,6 +35,10 @@ from .._kernels import code as _kernel_code from ._base import PlatformBackend as _PlatformBackend +# Module-level kernel compilation cache. Keyed on +# (device_index, compiler_optimisations). Stores compiled program binaries. +_kernel_cache = {} + class OpenCLPlatform(_PlatformBackend): """ @@ -122,39 +126,75 @@ def compile_kernels(self) -> _Dict[str, _Callable]: """ Compile OpenCL kernels and return callable functions. + Uses a module-level cache so that only the first sampler on a given + device pays the compilation cost. + Returns ------- dict Dictionary mapping kernel names to callable kernel functions. 
""" + cache_key = (self._device_index, self._compiler_optimisations) + # Build compiler options build_options = [] if self._compiler_optimisations: build_options.extend(["-cl-mad-enable", "-cl-no-signed-zeros"]) - # Compile program from source, suppressing stderr and warnings. - stderr_capture = _io.StringIO() - old_stderr = _sys.stderr - try: - _sys.stderr = stderr_capture - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore") - program = _cl.Program(self._context, _kernel_code).build( - options=build_options - ) - except _cl.RuntimeError as e: - stderr_output = stderr_capture.getvalue().strip() - error_msg = f"OpenCL kernel compilation failed: {e}" - if stderr_output: - error_msg += f"\n{stderr_output}" - raise RuntimeError(error_msg) - finally: - _sys.stderr = old_stderr - - # Capture the compiler log (including any warnings). - self._compiler_log = program.get_build_info( - self._device, _cl.program_build_info.LOG - ).strip() + if cache_key in _kernel_cache: + cached_binary = _kernel_cache[cache_key] + + # Create program from cached binary. + stderr_capture = _io.StringIO() + old_stderr = _sys.stderr + try: + _sys.stderr = stderr_capture + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + program = _cl.Program( + self._context, [self._device], [cached_binary] + ) + program.build(options=build_options) + except _cl.RuntimeError as e: + stderr_output = stderr_capture.getvalue().strip() + error_msg = f"OpenCL kernel build from cached binary failed: {e}" + if stderr_output: + error_msg += f"\n{stderr_output}" + raise RuntimeError(error_msg) + finally: + _sys.stderr = old_stderr + + self._compiler_log = "" + self._cache_hit = True + else: + # Compile program from source, suppressing stderr and warnings. 
+            stderr_capture = _io.StringIO()
+            old_stderr = _sys.stderr
+            try:
+                _sys.stderr = stderr_capture
+                with _warnings.catch_warnings():
+                    _warnings.simplefilter("ignore")
+                    program = _cl.Program(self._context, _kernel_code).build(
+                        options=build_options
+                    )
+            except _cl.RuntimeError as e:
+                stderr_output = stderr_capture.getvalue().strip()
+                error_msg = f"OpenCL kernel compilation failed: {e}"
+                if stderr_output:
+                    error_msg += f"\n{stderr_output}"
+                raise RuntimeError(error_msg)
+            finally:
+                _sys.stderr = old_stderr
+
+            # Capture the compiler log (including any warnings).
+            self._compiler_log = program.get_build_info(
+                self._device, _cl.program_build_info.LOG
+            ).strip()
+
+            self._cache_hit = False
+
+            # Cache the compiled binary.
+            _kernel_cache[cache_key] = program.get_info(_cl.program_info.BINARIES)[0]
 
         # Create kernel wrappers that match PyCUDA calling convention.
         # OpenCL kernels need (queue, global_size, local_size, *args)
@@ -189,6 +229,11 @@ def wrapper(*args, **kwargs):
 
         return kernels
 
+    @staticmethod
+    def clear_cache():
+        """Clear the kernel compilation cache."""
+        _kernel_cache.clear()
+
     def to_gpu(self, array: _np.ndarray) -> _Any:
         """
         Transfer a NumPy array to GPU memory.
diff --git a/src/loch/_sampler.py b/src/loch/_sampler.py
index dc90336..893cc17 100644
--- a/src/loch/_sampler.py
+++ b/src/loch/_sampler.py
@@ -759,6 +759,18 @@ def pop(self) -> None:
         """Pop the GPU context from the calling thread's context stack."""
         self._backend.pop_context()
 
+    @property
+    def kernel_cache_hit(self) -> bool:
+        """
+        Whether kernel compilation was satisfied from cache.
+
+        Returns
+        -------
+        bool
+            True if kernels were loaded from cache, False if freshly compiled.
+        """
+        return self._backend.cache_hit
+
     def system(self) -> _Any:
         """
         Return the GCMC system.
diff --git a/tests/test_compiler.py b/tests/test_compiler.py index b369da3..394b727 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -113,6 +113,9 @@ def test_compilation_error_raises_exception(self): nvcc=_get_nvcc(), ) + # Clear the cache so the patched code is actually compiled. + CUDAPlatform.clear_cache() + # Patch kernel code directly in the cuda module (not the kernels module, # since it's already imported as _kernel_code at module load time). original_code = cuda_module._kernel_code diff --git a/tests/test_energy.py b/tests/test_energy.py index 4eaaf82..64a806f 100644 --- a/tests/test_energy.py +++ b/tests/test_energy.py @@ -357,3 +357,90 @@ def test_energy_regression(fixture, platform, request): assert math.isclose( energy_lj, ref["energy_lj"], abs_tol=1e-4 ), f"LJ energy changed: {energy_lj!r} != {ref['energy_lj']!r}" + + +@pytest.mark.skipif( + "CUDA_VISIBLE_DEVICES" not in os.environ, + reason="Requires CUDA enabled GPU.", +) +@pytest.mark.parametrize("platform", ["cuda", "opencl"]) +def test_cached_kernel_correctness(platform, water_box): + """ + A second sampler using cached kernels must produce the same energies + as the first. 
+ """ + + mols, reference = water_box + + schedule = sr.cas.LambdaSchedule.standard_morph() + + def _create_and_run(seed): + sampler = GCMCSampler( + mols, + cutoff_type="rf", + cutoff="10 A", + reference=reference, + lambda_schedule=schedule, + lambda_value=0.5, + log_level="debug", + ghost_file=None, + log_file=None, + test=True, + platform=platform, + seed=seed, + ) + + d = sampler.system().dynamics( + cutoff_type="rf", + cutoff="10 A", + temperature="298 K", + pressure=None, + constraint="h_bonds", + timestep="2 fs", + schedule=schedule, + lambda_value=0.5, + coulomb_power=sampler._coulomb_power, + shift_coulomb=str(sampler._shift_coulomb), + shift_delta=str(sampler._shift_delta), + platform=platform, + ) + + is_accepted = False + while not is_accepted: + moves = sampler.move(d.context()) + if len(moves) > 0 and moves[0] == 0: + is_accepted = True + + return sampler + + # Clear the cache so the first sampler compiles from source. + if platform == "cuda": + from loch._platforms._cuda import CUDAPlatform + + CUDAPlatform.clear_cache() + else: + from loch._platforms._opencl import OpenCLPlatform + + OpenCLPlatform.clear_cache() + + # First sampler compiles kernels, second uses the cache. + # Both use the same seed so random water positions are identical. + sampler1 = _create_and_run(seed=42) + sampler2 = _create_and_run(seed=42) + + # Verify cache behaviour. + assert not sampler1.kernel_cache_hit, "First sampler should compile from source" + assert sampler2.kernel_cache_hit, "Second sampler should use cached kernels" + + # Verify energy consistency. 
+ energy1_coul = sampler1._debug["energy_coul"] + energy1_lj = sampler1._debug["energy_lj"] + energy2_coul = sampler2._debug["energy_coul"] + energy2_lj = sampler2._debug["energy_lj"] + + assert math.isclose( + energy1_coul, energy2_coul, abs_tol=1e-4 + ), f"Coulomb energy mismatch: {energy1_coul!r} vs {energy2_coul!r}" + assert math.isclose( + energy1_lj, energy2_lj, abs_tol=1e-4 + ), f"LJ energy mismatch: {energy1_lj!r} vs {energy2_lj!r}"