From e256565022d43d972f2f522daf9f68ff4e0b1470 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 25 Nov 2025 12:31:01 -0800 Subject: [PATCH] Refine DataLoaderConfig. statuses_to_fit to STATUSES_EXPECTING_DATA Summary: This change updates `DataLoaderConfig.statuses_to_fit` to use `STATUSES_EXPECTING_DATA` plus `TrialStatus.CANDIDATE` instead of `NON_ABANDONED_STATUSES`. This is a semantic narrowing: instead of including all trial statuses except `ABANDONED` (which includes `CANDIDATE`, `STAGED`, `FAILED`, `STALE`, etc.), it now only includes trials that are expected to have reliable data (`RUNNING`, `COMPLETED`, `EARLY_STOPPED`) plus `CANDIDATE` explicitly. **Changes:** - Changes implementation from `NON_ABANDONED_STATUSES` to `{*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE}` - Updates docstrings in both `statuses_to_fit` and `statuses_to_fit_map_metric` to accurately reflect the new behavior, explicitly listing the statuses: `RUNNING`, `COMPLETED`, `EARLY_STOPPED`, plus `CANDIDATE` Differential Revision: D83514647 --- ax/adapter/data_utils.py | 7 ++++--- ax/adapter/tests/test_data_utils.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ax/adapter/data_utils.py b/ax/adapter/data_utils.py index be1d9d27531..b1f8f499957 100644 --- a/ax/adapter/data_utils.py +++ b/ax/adapter/data_utils.py @@ -27,7 +27,7 @@ from ax.core.map_data import MAP_KEY, MapData from ax.core.map_metric import MapMetric from ax.core.observation import Observation, ObservationData, ObservationFeatures -from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus +from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus from ax.core.types import TParameterization from ax.exceptions.core import UnsupportedError from ax.utils.common.constants import Keys @@ -86,11 +86,12 @@ def __post_init__(self, fit_out_of_design: bool | None) -> None: def statuses_to_fit(self) -> set[TrialStatus]: """The data from trials in these statuses will be used to fit the model for non map metrics. Defaults to all trial statuses if - `fit_abandoned is True` and all statuses except ABANDONED, otherwise. + `fit_abandoned is True` and trials that are expected to have data + (RUNNING, COMPLETED, EARLY_STOPPED) plus CANDIDATE, otherwise. """ if self.fit_abandoned: return set(TrialStatus) - return NON_ABANDONED_STATUSES + return {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE} @property def statuses_to_fit_map_metric(self) -> set[TrialStatus]: diff --git a/ax/adapter/tests/test_data_utils.py b/ax/adapter/tests/test_data_utils.py index c682eed7699..9995b607552 100644 --- a/ax/adapter/tests/test_data_utils.py +++ b/ax/adapter/tests/test_data_utils.py @@ -16,7 +16,7 @@ from ax.core.data import Data from ax.core.map_data import MAP_KEY, MapData from ax.core.observation import Observation, ObservationData, ObservationFeatures -from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus +from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus from ax.exceptions.core import UnsupportedError from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -44,8 +44,13 @@ def test_data_loader_config(self) -> None: self.assertEqual(config.latest_rows_per_group, 1) self.assertIsNone(config.limit_rows_per_group) self.assertIsNone(config.limit_rows_per_metric) - self.assertEqual(config.statuses_to_fit, NON_ABANDONED_STATUSES) - self.assertEqual(config.statuses_to_fit_map_metric, NON_ABANDONED_STATUSES) + self.assertEqual( + config.statuses_to_fit, {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE} + ) + self.assertEqual( + config.statuses_to_fit_map_metric, + {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE}, + ) # Validation for latest / limit rows. with self.assertRaisesRegex(UnsupportedError, "must be None if either of"): DataLoaderConfig(latest_rows_per_group=1, limit_rows_per_metric=5)