diff --git a/ax/adapter/data_utils.py b/ax/adapter/data_utils.py index be1d9d27531..b1f8f499957 100644 --- a/ax/adapter/data_utils.py +++ b/ax/adapter/data_utils.py @@ -27,7 +27,7 @@ from ax.core.map_data import MAP_KEY, MapData from ax.core.map_metric import MapMetric from ax.core.observation import Observation, ObservationData, ObservationFeatures -from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus +from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus from ax.core.types import TParameterization from ax.exceptions.core import UnsupportedError from ax.utils.common.constants import Keys @@ -86,11 +86,12 @@ def __post_init__(self, fit_out_of_design: bool | None) -> None: def statuses_to_fit(self) -> set[TrialStatus]: """The data from trials in these statuses will be used to fit the model for non map metrics. Defaults to all trial statuses if - `fit_abandoned is True` and all statuses except ABANDONED, otherwise. + `fit_abandoned is True` and trials that are expected to have data + (RUNNING, COMPLETED, EARLY_STOPPED) plus CANDIDATE, otherwise. """ if self.fit_abandoned: return set(TrialStatus) - return NON_ABANDONED_STATUSES + return {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE} @property def statuses_to_fit_map_metric(self) -> set[TrialStatus]: diff --git a/ax/adapter/tests/test_data_utils.py b/ax/adapter/tests/test_data_utils.py index c682eed7699..9995b607552 100644 --- a/ax/adapter/tests/test_data_utils.py +++ b/ax/adapter/tests/test_data_utils.py @@ -16,7 +16,7 @@ from ax.core.data import Data from ax.core.map_data import MAP_KEY, MapData from ax.core.observation import Observation, ObservationData, ObservationFeatures -from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus +from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus from ax.exceptions.core import UnsupportedError from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -44,8 +44,13 @@ def test_data_loader_config(self) -> None: self.assertEqual(config.latest_rows_per_group, 1) self.assertIsNone(config.limit_rows_per_group) self.assertIsNone(config.limit_rows_per_metric) - self.assertEqual(config.statuses_to_fit, NON_ABANDONED_STATUSES) - self.assertEqual(config.statuses_to_fit_map_metric, NON_ABANDONED_STATUSES) + self.assertEqual( + config.statuses_to_fit, {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE} + ) + self.assertEqual( + config.statuses_to_fit_map_metric, + {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE}, + ) # Validation for latest / limit rows. with self.assertRaisesRegex(UnsupportedError, "must be None if either of"): DataLoaderConfig(latest_rows_per_group=1, limit_rows_per_metric=5)