Conversation
Co-authored-by: drewoldag <47493171+drewoldag@users.noreply.github.com>
drewoldag
left a comment
This looks good to me.
Pull request overview
This PR adds fine-tuning / transfer learning support by introducing a model_weights_file key under the [train] config section. Users can now point to pre-trained weights that will be loaded before training begins (using only model parameters, not optimizer state), which is distinct from resume (full checkpoint restore). An early ValueError is raised if both are set simultaneously.
Changes:
- New `model_weights_file = false` config key in `[train]` with descriptive comment
- Mutual-exclusivity validation and conditional weight loading (before `create_trainer`) added to `Train.run()`
- Three new tests covering the conflict error, a full fine-tuning run, and the default value
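The new key in the default config might look like the fragment below (a sketch only; the comment wording is illustrative, not the exact text merged in the PR):

```toml
[train]
# ... existing keys ...
resume = false
# If not false, path to a pre-trained model weights file to load before
# training begins (fine-tuning / transfer learning). Mutually exclusive
# with `resume`, which restores a full checkpoint instead.
model_weights_file = false
```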
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `src/hyrax/hyrax_default_config.toml` | Adds `model_weights_file = false` key with comment in `[train]` section |
| `src/hyrax/verbs/train.py` | Adds conflict validation, conditional weight loading with colorama logging in `Train.run()` |
| `tests/hyrax/test_train.py` | Adds three tests: conflict raises `ValueError`, successful fine-tuning run, default value is `False` |
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##             main     #752      +/-   ##
==========================================
+ Coverage   64.66%   64.70%   +0.04%
==========================================
  Files          61       61
  Lines        5881     5888       +7
==========================================
+ Hits         3803     3810       +7
  Misses       2078     2078
```
aritraghsh09
left a comment
I like the simple, clean implementation!
Change Description

Adds `model_weights_file` to the `[train]` config to allow loading pre-trained weights before training begins, enabling fine-tuning and transfer learning workflows. This is distinct from `resume`, which restores a full checkpoint (optimizer state, epoch counter).

Solution Description

- `hyrax_default_config.toml`: Added `model_weights_file = false` directly after the `resume` key in `[train]`, with a concise 2-line comment clarifying its purpose and mutual exclusivity with `resume`. The key is placed adjacent to `resume` to make their relationship clear.
- `train.py`: Two additions to `Train.run()`. First, a `ValueError` if both `resume` and `model_weights_file` are set simultaneously (fails before any dataset loading or directory creation). Second, a call to `load_model_weights(config, model, "train")` after `setup_model()` but before `create_trainer()`, which is critical to avoid key mismatches from `idist.auto_model` wrapping. The two colorama-styled log lines confirming the weights path and fine-tuning mode are emitted after `load_model_weights` succeeds, ensuring no misleading output if loading fails.
- `test_train.py`: Three new tests: conflict raises `ValueError`, successful fine-tuning run, default value is `False`.

Usage:
Setting both `resume` and `model_weights_file` raises immediately:
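The failure mode can be sketched with a small stand-alone check. The function name and message wording below are hypothetical; the real check lives inline in `Train.run()`:

```python
def check_train_config(train_cfg: dict) -> None:
    """Sketch of the PR's mutual-exclusivity check (illustrative, not the
    exact Hyrax code): reject configs that set both keys."""
    if train_cfg.get("resume") and train_cfg.get("model_weights_file"):
        raise ValueError(
            "[train] resume and [train] model_weights_file are mutually "
            "exclusive: resume restores a full checkpoint (model, optimizer "
            "state, epoch counter), while model_weights_file loads only "
            "model parameters for fine-tuning. Please set at most one."
        )

# A conflicting config fails fast, before any dataset loading:
conflicting = {"resume": "checkpoint.pt", "model_weights_file": "weights.pth"}
try:
    check_train_config(conflicting)
except ValueError as err:
    print(err)
```

Because the check runs before any expensive setup, the user gets the error in well under a second rather than after dataset loading.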
Original prompt
This section details the original issue you should resolve
<issue_title>Enable training starting with a pre-trained model</issue_title>
<issue_description>Feature request

The user should be able to provide the path to a model weights file via the config, which would instruct Hyrax to load the model weights using the existing `load_model_weights` function.

The key in the config should be under the `train` table and should be called `model_weights_file`. Its value should be `false` by default. It should have a descriptive comment stating that if its value is not false, it should be a path to a model weights file.

An error should be raised quickly if the user specifies both a resume (checkpoint path) in the config and a `model_weights_file` in the config.

There should be some logging associated with using a pre-trained weights file, if one is provided, to ensure that the user is aware of what is happening. </issue_description>
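Taken together, the requirements above can be sketched as one helper. This is a hypothetical illustration: plain `logging` stands in for Hyrax's colorama-styled output, and the `load_weights` callable is a placeholder for the real `load_model_weights`:

```python
import logging

logger = logging.getLogger("hyrax.train")


def maybe_load_pretrained(config: dict, load_weights) -> bool:
    """Sketch of the issue's requested behavior (not Hyrax's actual code).

    Returns True when pre-trained weights were loaded.
    """
    train_cfg = config["train"]
    weights_file = train_cfg.get("model_weights_file", False)
    if not weights_file:
        return False  # default false: train from scratch
    if train_cfg.get("resume"):
        # Fail quickly, before any expensive setup.
        raise ValueError("Set either resume or model_weights_file, not both.")
    load_weights(weights_file)  # placeholder for load_model_weights(...)
    logger.info("Starting from pre-trained weights: %s (fine-tuning).", weights_file)
    return True
```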
<agent_instructions>@copilot the second comment in the issue is the plan created by the PM agent. Please follow that plan for implementing this feature.</agent_instructions>
Comments on the Issue (you are @copilot in this section)
@drewoldag

## Plan: Add `model_weights_file` to `[train]` Config

TL;DR: Add a `model_weights_file` key to the `[train]` config section so users can load pre-trained model weights (for fine-tuning / transfer learning) separately from full checkpoint resuming (`resume`). The two keys serve distinct purposes: `resume` restores the full training state (optimizer, scheduler, epoch counter), while `model_weights_file` loads only model parameters and starts training fresh. An early `ValueError` prevents setting both simultaneously.

Steps

1. Add config key in `src/hyrax/hyrax_default_config.toml`: insert `model_weights_file = false` in the `[train]` table (after the existing `resume` entry). The comment should read something like: […]
2. Add early validation in `src/hyrax/verbs/train.py`: at the top of `Train.run()`, immediately after `config = self.config`, add a check: if both `config["train"]["resume"]` and `config["train"]["model_weights_file"]` are truthy, raise a `ValueError` with a clear message explaining the difference and asking the user to pick one. This runs before any expensive dataset loading or directory creation.
3. Load pre-trained weights in `src/hyrax/verbs/train.py`: after `model = setup_model(config, dataset["train"])` (line ~72) and before `create_trainer(...)` (line ~141). If `config["train"]["model_weights_file"]` is truthy, call the existing `load_model_weights(config, model, "train")` from `src/hyrax/models/model_utils.py`. This must happen before `create_trainer` because `create_trainer` wraps the model with `idist.auto_model` (a distributed wrapper), which can alter parameter key names. Loading weights into the un-wrapped model avoids key mismatches.
4. Add logging: in `Train.run()`, after the `load_model_weights` call succeeds, log a colorama-styled message (matching the existing pattern like `{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}...{Style.RESET_ALL}`) saying something like: […] This makes it unmistakable to the user what's happening and how it differs from `resume`.
5. Handle the `load_model_weights` fallback behavior: the existing function in `model_utils.py` falls back to auto-discovering the most recent training results when `model_weights_file` is falsy. Since step 3 only calls the function when the value is truthy, this fallback won't trigger. No changes to `load_model_weights` are needed.
6. Add tests in `tests/hyrax/test_train.py`: three new test functions:
   - `test_train_raises_on_resume_and_model_weights_file`: set both `config["train"]["resume"]` and `config["train"]["model_weights_file"]` to non-false values, call `h.train()`, assert `ValueError` is raised with an appropriate message. Follow the existing `loopback_hyrax` fixture pattern.
   - `test_train_with_pretrained_weights`: run a first training to produce a weights file, then set `config["train"]["model_weights_file"]` to that weights file path, run training again, and assert it completes successfully. Verify the model was loaded from the sp…
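The ordering constraint in step 3 (load weights before `idist.auto_model` wrapping) can be illustrated with a toy sketch. The classes below are stand-ins, not Hyrax or Ignite code; real distributed wrappers such as `DistributedDataParallel` similarly prefix parameter names with `module.`:

```python
class TinyModel:
    """Stand-in for the raw model returned by setup_model()."""

    def __init__(self):
        self.state = {"layer.weight": 0.0}

    def load_state_dict(self, weights):
        # Fail loudly on unexpected keys, like strict state-dict loading.
        unexpected = set(weights) - set(self.state)
        if unexpected:
            raise KeyError(f"unexpected keys: {sorted(unexpected)}")
        self.state.update(weights)


class AutoModelWrapper:
    """Stand-in for idist.auto_model: prefixes parameter keys with 'module.'."""

    def __init__(self, model):
        self.state = {f"module.{k}": v for k, v in model.state.items()}


pretrained = {"layer.weight": 1.5}  # keys as saved from an un-wrapped model

# Correct order: load into the raw model, then wrap.
model = TinyModel()
model.load_state_dict(pretrained)
wrapped = AutoModelWrapper(model)

# Wrong order: the wrapper exposes 'module.layer.weight', so the saved key
# 'layer.weight' would be an unexpected key if loaded after wrapping.
wrapped_first = AutoModelWrapper(TinyModel())
```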