Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ax/adapter/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ax.core.map_data import MAP_KEY, MapData
from ax.core.map_metric import MapMetric
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus
from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError
from ax.utils.common.constants import Keys
Expand Down Expand Up @@ -86,11 +86,12 @@ def __post_init__(self, fit_out_of_design: bool | None) -> None:
def statuses_to_fit(self) -> set[TrialStatus]:
"""The data from trials in these statuses will be used to fit the model
for non map metrics. Defaults to all trial statuses if
`fit_abandoned is True` and all statuses except ABANDONED, otherwise.
`fit_abandoned is True` and trials that are expected to have data
(RUNNING, COMPLETED, EARLY_STOPPED) plus CANDIDATE, otherwise.
"""
if self.fit_abandoned:
return set(TrialStatus)
return NON_ABANDONED_STATUSES
return {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE}

@property
def statuses_to_fit_map_metric(self) -> set[TrialStatus]:
Expand Down
11 changes: 8 additions & 3 deletions ax/adapter/tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ax.core.data import Data
from ax.core.map_data import MAP_KEY, MapData
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus
from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -44,8 +44,13 @@ def test_data_loader_config(self) -> None:
self.assertEqual(config.latest_rows_per_group, 1)
self.assertIsNone(config.limit_rows_per_group)
self.assertIsNone(config.limit_rows_per_metric)
self.assertEqual(config.statuses_to_fit, NON_ABANDONED_STATUSES)
self.assertEqual(config.statuses_to_fit_map_metric, NON_ABANDONED_STATUSES)
self.assertEqual(
config.statuses_to_fit, {*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE}
)
self.assertEqual(
config.statuses_to_fit_map_metric,
{*STATUSES_EXPECTING_DATA, TrialStatus.CANDIDATE},
)
# Validation for latest / limit rows.
with self.assertRaisesRegex(UnsupportedError, "must be None if either of"):
DataLoaderConfig(latest_rows_per_group=1, limit_rows_per_metric=5)
Expand Down