From 0cafdb188206da9685883cfc38f3fb8e6e6380f3 Mon Sep 17 00:00:00 2001
From: "Peter St. John"
Date: Mon, 15 Dec 2025 12:45:40 -0800
Subject: [PATCH] manual gc in llama3

Signed-off-by: Peter St. John
---
 .../llama3_native_te/hydra_config/defaults.yaml |  1 +
 .../recipes/llama3_native_te/perf_logger.py     | 14 ++++++++++++++
 .../recipes/llama3_native_te/train_fsdp2.py     |  4 ----
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index a4e34fd1df..cb1e56c1b9 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -74,6 +74,7 @@ logger:
 
 profiler:
   enabled: false
+  gc_interval: 1_000 # Run garbage collection every 1000 steps
   schedule:
     wait: 10
     warmup: 10
diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
index bbdfdd5987..ddde87a57e 100644
--- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import logging
 import time
 from pathlib import Path
@@ -71,6 +72,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         self.previous_step_time = time.perf_counter()
         self._profiler = None
 
+        # Manually control garbage collection for cleaner profiling.
+        self._gc_interval = args.profiler.gc_interval
+        gc.disable()
+        self._run_garbage_collection()
+
         if self._dist_config.is_main_process():
             # Log the entire args object to wandb for experiment tracking and reproducibility.
             self._wandb_run = wandb.init(**args.wandb, config=self._run_config)
@@ -134,6 +140,9 @@ def log_step(
         if self._dist_config.local_rank == 0:
             logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
 
+        if (step + 1) % self._gc_interval == 0:
+            self._run_garbage_collection()
+
     def finish(self):
         """Finish the logger and close the progress bar."""
         if self._profiler is not None:
@@ -145,6 +154,11 @@
             wandb.finish()
         self._progress_bar.close()
 
+    def _run_garbage_collection(self):
+        """Run garbage collection."""
+        gc.collect()
+        torch.cuda.empty_cache()
+
 
 def setup_profiler(args: DictConfig, wandb_run: wandb.Run):
     """Setup a basic torch profiler for the experiment.
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index 2915b4af7f..7741f267e4 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 from contextlib import nullcontext
 from pathlib import Path
@@ -128,9 +127,6 @@ def main(args: DictConfig) -> float | None:
 
     perf_logger = PerfLogger(dist_config, args)
 
-    gc.collect()
-    torch.cuda.empty_cache()
-
     # Training loop
     logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
     step = start_step
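
For anyone trying the change outside the recipe, below is a minimal sketch of the pattern this patch wires into `PerfLogger`: disable the automatic garbage collector up front, then run an explicit collection every `gc_interval` steps so GC pauses land at known step boundaries instead of firing mid-step inside profiled regions. `train_step` and `NUM_STEPS` are placeholders for illustration, not part of the recipe.

```python
import gc

import torch

GC_INTERVAL = 1_000  # mirrors profiler.gc_interval in defaults.yaml
NUM_STEPS = 5_000  # placeholder step count


def train_step() -> None:
    """Placeholder for one optimizer step of the real training loop."""


def run_garbage_collection() -> None:
    """Run a full Python GC pass and release cached CUDA allocator blocks."""
    gc.collect()
    torch.cuda.empty_cache()  # no-op if CUDA was never initialized


# Disable automatic GC so collections can't trigger at arbitrary points
# and show up as spurious latency spikes in profiler traces.
gc.disable()
run_garbage_collection()

for step in range(NUM_STEPS):
    train_step()
    if (step + 1) % GC_INTERVAL == 0:
        run_garbage_collection()
```

Keeping the interval in the `profiler` config block, and the collection logic in `PerfLogger` rather than hard-coded in `train_fsdp2.py`, puts the GC cadence next to the profiling schedule it exists to protect, which is presumably why the patch moves the `gc.collect()` call out of the training script.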