From 2d2a01f4e9fe7d61436fcc050e7d71ca0eafc8ab Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 25 Dec 2025 12:08:24 -0800 Subject: [PATCH 01/27] fix to compevent models --- pySEQTarget/SEQuential.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 4815ffd..6ffbcca 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -342,7 +342,7 @@ def collect(self) -> SEQoutput: } if self.compevent_colname is not None: - compevent_models = [model["compevent"] for model in self.outcome_models] + compevent_models = [model["compevent"] for model in self.outcome_model] else: compevent_models = None diff --git a/pyproject.toml b/pyproject.toml index bf85936..acf55a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.10.3" +version = "0.10.4" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 30c14db6a2a3a6564c8df2c1df89093f8b8fcf51 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 12:59:07 -0800 Subject: [PATCH 02/27] prepare gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 4e28064..69d0b0e 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,6 @@ cython_debug/ # uv lock file uv.lock + +# offloaded data files (offload test) +_seq_models/ \ No newline at end of file From 2ca636a4a93fae31a5d31097950a0ef8c41bdd7c Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 12:59:15 -0800 Subject: [PATCH 03/27] add joblib --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index acf55a2..c127d36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,8 @@ dependencies = [ "statsmodels", "matplotlib", "pyarrow", - "lifelines" + "lifelines", + "joblib" ] [project.optional-dependencies] From e5a8fe754a96a5f4d5cf47c07fb12bfbd1004363 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:00:43 -0800 Subject: [PATCH 04/27] add offload (and docs for visit) --- pySEQTarget/SEQopts.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 1d09061..9de9cea 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -1,6 +1,7 @@ import multiprocessing from dataclasses import dataclass, field from typing import List, Literal, Optional +import os @dataclass @@ -56,6 +57,10 @@ class SEQopts: :type ncores: int :param numerator: Override to specify the outcome patsy formula for numerator models :type numerator: str + :param offload: Boolean to offload intermediate model data to disk + :type offload: bool + :param offload_dir: Directory to offload intermediate model data + :type offload_dir: str :param parallel: Boolean to run model fitting in parallel :type parallel: bool :param plot_colors: List of colors for KM plots, if applicable @@ -80,6 +85,8 @@ class SEQopts: :type treatment_level: List[int] :param trial_include: Boolean to force trial values into model covariates :type trial_include: bool + :param visit_colname: Column name specifying visit number + :type visit_colname: str :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting :type weight_eligible_colnames: List[str] :param weight_min: Minimum weight @@ -120,6 +127,8 @@ class SEQopts: km_curves: bool = False ncores: int = multiprocessing.cpu_count() numerator: Optional[str] = None + offload: bool = False + offload_dir: str = "_seq_models" parallel: bool = False plot_colors: List[str] = field( default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"] @@ -195,3 +204,7 @@ def __post_init__(self): attr = getattr(self, i) if attr is not None and not isinstance(attr, list): setattr(self, i, "".join(attr.split())) + + if self.offload: + os.makedirs(self.offload_dir, exist_ok=True) + From 5c99c935c93504534f279e42ff0602ca430c42f0 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:00:57 -0800 Subject: [PATCH 05/27] create offloader --- pySEQTarget/helpers/_offloader.py | 35 +++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 pySEQTarget/helpers/_offloader.py diff --git a/pySEQTarget/helpers/_offloader.py b/pySEQTarget/helpers/_offloader.py new file mode 100644 index 0000000..c88c526 --- /dev/null +++ b/pySEQTarget/helpers/_offloader.py @@ -0,0 +1,35 @@ +from pathlib import Path +from typing import Any, Optional, Union +import joblib +import polars as pl + + +class Offloader: + """Manages disk-based storage for models and intermediate data""" + + def __init__(self, + enabled: bool, + dir: str, + compression: int = 3 + ): + self.enabled = enabled + self.dir = Path(dir) + self.compression = compression + + def save_model(self, model: Any, name: str, boot_idx: Optional[int] = None) -> Union[Any, str]: + """Save a fitted model to disk and return a reference""" + if not self.enabled: + return model + + filename = f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl" + filepath = self.dir / filename + + joblib.dump(model, filepath, compress=self.compression) + + return str(filepath) + + def load_model(self, ref: Union[Any, str]) -> Any: + if not self.enabled or not isinstance(ref, str): + return ref + + return joblib.load(ref) \ No newline at end of file From ac2e88deff1221aeae94b160ed6b799a9f6ff3a3 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:01:41 -0800 Subject: [PATCH 06/27] setup offloading in primary API --- pySEQTarget/SEQuential.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 6ffbcca..a77ba17 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -12,7 +12,7 @@ _subgroup_fit) from .error import _data_checker, _param_checker from .expansion import _binder, _diagnostics, _dynamic, _random_selection -from .helpers import _col_string, _format_time, bootstrap_loop +from .helpers import _col_string, _format_time, bootstrap_loop, Offloader from .initialization import (_cense_denominator, _cense_numerator, _denominator, _numerator, _outcome) from .plot import _survival_plot @@ -83,6 +83,10 @@ def __init__( self._rng = ( np.random.RandomState(self.seed) if self.seed is not None else np.random ) + + self._offloader = Offloader( + enabled = self.offload, + dir = self.offload_dir) if self.covariates is None: self.covariates = _outcome(self) @@ -216,6 +220,9 @@ def fit(self) -> None: _fit_visit(self, WDT) _fit_numerator(self, WDT) _fit_denominator(self, WDT) + + if self.offload: + _offload_weights(self, boot_idx) WDT = pl.from_pandas(WDT) WDT = _weight_predict(self, WDT) @@ -244,6 +251,11 @@ def fit(self) -> None: self.weighted, "weight", ) + if self.offload: + offloaded_models = {} + for key, model in models.items(): + offloaded_models[key] = self._offloader.save_model(model, key, boot_idx) + return offloaded_models return models def survival(self, **kwargs) -> None: From 2239d41be0067c7e1397265f674c0d2ac5edeee7 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:01:55 -0800 Subject: [PATCH 07/27] offloader to init --- pySEQTarget/helpers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pySEQTarget/helpers/__init__.py b/pySEQTarget/helpers/__init__.py index f45531a..93f1b69 100644 --- a/pySEQTarget/helpers/__init__.py +++ b/pySEQTarget/helpers/__init__.py @@ -6,3 +6,4 @@ from ._pad import _pad as _pad from ._predict_model import _predict_model as _predict_model from ._prepare_data import _prepare_data as _prepare_data +from ._offloader import Offloader as Offloader From 2bf49ec5eed2fe133a6c43ed2dd6bffa3be83be0 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:03:18 -0800 Subject: [PATCH 08/27] add boot idx --- pySEQTarget/SEQuential.py | 5 ++++- pySEQTarget/helpers/_bootstrap.py | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index a77ba17..640e969 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -205,7 +205,10 @@ def fit(self) -> None: raise ValueError( "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping." ) - + boot_idx = None + if hasattr(self, "_current_boot_idx"): + boot_idx = self._current_boot_idx + if self.weighted: WDT = _weight_setup(self) if not self.weight_preexpansion and not self.excused: diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 7c4f322..628cce9 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -40,6 +40,7 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() ) obj.DT = _prepare_boot_data(obj, original_DT, i) + obj._current_boot_idx = i + 1 # Disable bootstrapping to prevent recursion obj.bootstrap_nboot = 0 @@ -59,7 +60,8 @@ def wrapper(self, *args, **kwargs): results = [] original_DT = self.DT - + + self._current_boot_idx = None full = method(self, *args, **kwargs) results.append(full) @@ -97,9 +99,12 @@ def wrapper(self, *args, **kwargs): self._rng = original_rng else: for i in tqdm(range(nboot), desc="Bootstrapping..."): + self._current_boot_idx = i + 1 self.DT = _prepare_boot_data(self, original_DT, i) + self.bootstrap_nboot = 0 boot_fit = method(self, *args, **kwargs) results.append(boot_fit) + self.bootstrap_nboot = nboot self.DT = original_DT From 1f4e23c935d0f09afd48fd3fdcea94ccf80f2401 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:03:39 -0800 Subject: [PATCH 09/27] add intake from unloaded models --- pySEQTarget/analysis/_hazard.py | 6 ++++-- pySEQTarget/analysis/_survival_pred.py | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 4c667c9..27d32f4 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -93,8 +93,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng): else: model_dict = self.outcome_model[boot_idx] - outcome_model = model_dict["outcome"] - ce_model = model_dict.get("compevent", None) if self.compevent_colname else None + outcome_model = self._offloader.load_model(model_dict["outcome"]) + ce_model = None + if self.compevent_colname and "compevent" in model_dict: + ce_model = self._offloader.load_model(model_dict["compevent"]) all_treatments = [] for val in self.treatment_level: diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index fb54385..8b98b6b 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -9,9 +9,12 @@ def _get_outcome_predictions(self, TxDT, idx=None): for boot_model in self.outcome_model: model_dict = boot_model[idx] if idx is not None else boot_model - predictions["outcome"].append(model_dict["outcome"].predict(data)) + outcome_model = self._offloader.load_model(model_dict["outcome"]) + predictions["outcome"].append(outcome_model.predict(data)) + if self.compevent_colname is not None: - predictions["compevent"].append(model_dict["compevent"].predict(data)) + compevent_model = self._offloader.load_model(model_dict["compevent"]) + predictions["compevent"].append(compevent_model.predict(data)) return predictions From c8af80a25d368bee9529e9e00198457868ed576e Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:03:52 -0800 Subject: [PATCH 10/27] setup weight offloading --- pySEQTarget/weighting/_weight_offload.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 pySEQTarget/weighting/_weight_offload.py diff --git a/pySEQTarget/weighting/_weight_offload.py b/pySEQTarget/weighting/_weight_offload.py new file mode 100644 index 0000000..db90fb5 --- /dev/null +++ b/pySEQTarget/weighting/_weight_offload.py @@ -0,0 +1,17 @@ +def _offload_weights(self, boot_idx): + """Helper to offload weight models to disk""" + weight_models = [ + ('numerator_model', 'numerator'), + ('denominator_model', 'denominator'), + ('LTFU_model', 'LTFU'), + ('visit_model', 'visit'), + ] + + for model_attr, model_name in weight_models: + if hasattr(self, model_attr): + model_list = getattr(self, model_attr) + if model_list and isinstance(model_list, list) and len(model_list) > 0: + latest_model = model_list[-1] + if latest_model is not None: + offloaded = self._offloader.save_model(latest_model, model_name, boot_idx) + model_list[-1] = offloaded \ No newline at end of file From 46fcf7aef3c4e5c31915992b8773c73c306bba9a Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:04:07 -0800 Subject: [PATCH 11/27] weight offload to init --- pySEQTarget/weighting/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py index 65e5ca7..d037d91 100644 --- a/pySEQTarget/weighting/__init__.py +++ b/pySEQTarget/weighting/__init__.py @@ -6,3 +6,4 @@ from ._weight_fit import _fit_visit as _fit_visit from ._weight_pred import _weight_predict as _weight_predict from ._weight_stats import _weight_stats as _weight_stats +from ._weight_offload import _offload_weights as _offload_weights From c589967ac70cdb4197930c1b847014870640748c Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:05:55 -0800 Subject: [PATCH 12/27] add weight offloading to SEQ + weight inloading --- pySEQTarget/SEQuential.py | 2 +- pySEQTarget/weighting/_weight_bind.py | 5 ++++ pySEQTarget/weighting/_weight_pred.py | 34 +++++++++++++++++---------- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 640e969..0408a58 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -20,7 +20,7 @@ from .SEQoutput import SEQoutput from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, _fit_visit, _weight_bind, _weight_predict, - _weight_setup, _weight_stats) + _weight_setup, _weight_stats, _offload_weights) class SEQuential: diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index 91e50c6..41bba44 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -6,6 +6,11 @@ def _weight_bind(self, WDT): join = "inner" on = [self.id_col, "period"] WDT = WDT.rename({self.time_col: "period"}) + self.DT = self.DT.with_columns( + pl.col(self.id_col) + .str.replace(r"_\d+$", "") + .alias(self.id_col) + ) else: join = "left" on = [self.id_col, "trial", "followup"] diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 8f885e6..a3a611f 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -22,12 +22,16 @@ def _weight_predict(self, WDT): for i, level in enumerate(self.treatment_level): mask = pl.col("tx_lag") == level lag_mask = (WDT["tx_lag"] == level).to_numpy() + + denom_model = self._offloader.load_model(self.denominator_model[i]) + num_model = self._offloader.load_model(self.numerator_model[i]) - if self.denominator_model[i] is not None: + + if denom_model is not None: pred_denom = np.ones(WDT.height) if lag_mask.sum() > 0: subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.denominator_model[i], subset) + p = _predict_model(self, denom_model, subset) if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] @@ -38,11 +42,11 @@ def _weight_predict(self, WDT): else: pred_denom = np.ones(WDT.height) - if self.numerator_model[i] is not None: + if num_model is not None: pred_num = np.ones(WDT.height) if lag_mask.sum() > 0: subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.numerator_model[i], subset) + p = _predict_model(self, num_model, subset) if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] @@ -71,12 +75,13 @@ def _weight_predict(self, WDT): col = self.excused_colnames[i] if col is not None: + denom_model = self._offloader.load_model(self.denominator_model[i]) denom_mask = ((WDT["tx_lag"] == level) & (WDT[col] != 1)).to_numpy() - if self.denominator_model[i] is not None and denom_mask.sum() > 0: + if denom_model is not None and denom_mask.sum() > 0: pred_denom = np.ones(WDT.height) subset = WDT.filter(pl.Series(denom_mask)) - p = _predict_model(self, self.denominator_model[i], subset) + p = _predict_model(self, denom_model, subset) if p.ndim == 1: prob_switch = p @@ -119,14 +124,15 @@ def _weight_predict(self, WDT): col = self.excused_colnames[i] if col is not None: + num_model = self._offloader.load_model(self.numerator_model[i]) num_mask = ( (WDT[self.treatment_col] == level) & (WDT[col] == 0) ).to_numpy() - if self.numerator_model[i] is not None and num_mask.sum() > 0: + if num_model is not None and num_mask.sum() > 0: pred_num = np.ones(WDT.height) subset = WDT.filter(pl.Series(num_mask)) - p = _predict_model(self, self.numerator_model[i], subset) + p = _predict_model(self, num_model, subset) if p.ndim == 1: prob_switch = p @@ -150,8 +156,10 @@ def _weight_predict(self, WDT): .alias("numerator") ) if self.cense_colname is not None: - p_num = _predict_model(self, self.cense_numerator_model, WDT).flatten() - p_denom = _predict_model(self, self.cense_denominator_model, WDT).flatten() + cense_num_model = self._offloader.load_model(self.cense_numerator_model) + cense_denom_model = self._offloader.load_model(self.cense_denominator_model) + p_num = _predict_model(self, cense_num_model, WDT).flatten() + p_denom = _predict_model(self, cense_denom_model, WDT).flatten() WDT = WDT.with_columns( [ pl.Series("cense_numerator", p_num), @@ -164,8 +172,10 @@ def _weight_predict(self, WDT): WDT = WDT.with_columns(pl.lit(1.0).alias("_cense")) if self.visit_colname is not None: - p_num = _predict_model(self, self.visit_numerator_model, WDT).flatten() - p_denom = _predict_model(self, self.visit_denominator_model, WDT).flatten() + visit_num_model = self._offloader.load_model(self.visit_numerator_model) + visit_denom_model = self._offloader.load_model(self.visit_denominator_model) + p_num = _predict_model(self, visit_num_model, WDT).flatten() + p_denom = _predict_model(self, visit_denom_model, WDT).flatten() WDT = WDT.with_columns( [ From 99747213fe6596ef3708b6d5561adcb7f72a0ad2 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:06:02 -0800 Subject: [PATCH 13/27] test offload --- tests/test_offload.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/test_offload.py diff --git a/tests/test_offload.py b/tests/test_offload.py new file mode 100644 index 0000000..fa802f3 --- /dev/null +++ b/tests/test_offload.py @@ -0,0 +1,40 @@ +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + +import warnings + +def test_compevent_offload(): + data = load_data("SEQdata_LTFU") + options = SEQopts( + bootstrap_nboot = 20, + cense_colname = "LTFU", + excused = True, + excused_colnames = ["excusedZero", "excusedOne"], + km_curves = True, + selection_random = True, + selection_sample = 0.30, + weighted = True, + weight_lag_condition=False, + weight_p99 = True, + weight_preexpansion = True, + offload = True + ) + + model = SEQuential(data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=options) + model.expand() + model.bootstrap() + # Warnings from statsmodels about overflow in some bootstraps + warnings.filterwarnings("ignore") + model.fit() + model.survival() + + \ No newline at end of file From 5bb0d23081d4c81a43127508ae2ba612ec8d9697 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:06:09 -0800 Subject: [PATCH 14/27] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c127d36..8633196 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.10.4" +version = "0.11.0" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 953ca85868f063a696bd8e69231104900704c479 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:13:09 -0800 Subject: [PATCH 15/27] setup offloading for dataframes --- pySEQTarget/helpers/_offloader.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/pySEQTarget/helpers/_offloader.py b/pySEQTarget/helpers/_offloader.py index c88c526..7b02983 100644 --- a/pySEQTarget/helpers/_offloader.py +++ b/pySEQTarget/helpers/_offloader.py @@ -32,4 +32,21 @@ def load_model(self, ref: Union[Any, str]) -> Any: if not self.enabled or not isinstance(ref, str): return ref - return joblib.load(ref) \ No newline at end of file + return joblib.load(ref) + + def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]: + if not self.enabled: + return df + + filename = f"{name}.parquet" + filepath = self.dir / filename + + df.write_parquet(filepath, compression="zstd") + + return str(filepath) + + def load_dataframe(self, ref: Union[pl.DataFrame, str]) -> pl.DataFrame: + if not self.enabled or not isinstance(ref, str): + return ref + + return pl.read_parquet(ref) \ No newline at end of file From 9f1827e825f728c426ee64ca5d604c4337ebc4b9 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:13:20 -0800 Subject: [PATCH 16/27] offload original DT while bootstrapping --- pySEQTarget/helpers/_bootstrap.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 628cce9..2f37596 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -39,7 +39,9 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): obj._rng = ( np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() ) + original_DT = obj._offloader.load_dataframe(original_DT) obj.DT = _prepare_boot_data(obj, original_DT, i) + del original_DT obj._current_boot_idx = i + 1 # Disable bootstrapping to prevent recursion @@ -72,10 +74,13 @@ def wrapper(self, *args, **kwargs): ncores = self.ncores seed = getattr(self, "seed", None) method_name = method.__name__ + + original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT") if getattr(self, "parallel", False): original_rng = getattr(self, "_rng", None) self._rng = None + self.DT = None with ProcessPoolExecutor(max_workers=ncores) as executor: futures = [ @@ -83,7 +88,7 @@ def wrapper(self, *args, **kwargs): _bootstrap_worker, self, method_name, - original_DT, + original_DT_ref, i, seed, args, @@ -97,16 +102,21 @@ def wrapper(self, *args, **kwargs): results.append(j.result()) self._rng = original_rng + self.DT = self._offloader.load_dataframe(original_DT_ref) else: + original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT") + del original_DT for i in tqdm(range(nboot), desc="Bootstrapping..."): self._current_boot_idx = i + 1 - self.DT = _prepare_boot_data(self, original_DT, i) + tmp = self._offloader.load_dataframe(original_DT_ref) + self.DT = _prepare_boot_data(self, tmp, i) + del tmp self.bootstrap_nboot = 0 boot_fit = method(self, *args, **kwargs) results.append(boot_fit) - self.bootstrap_nboot = nboot - - self.DT = original_DT + + self.bootstrap_nboot = nboot + self.DT = self._offloader.load_dataframe(original_DT_ref) end = time.perf_counter() self._model_time = _format_time(start, end) From f47d7381e466abb7dcc70edfb6ac87c78bcdb871 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:16:07 -0800 Subject: [PATCH 17/27] adjust test nboot --- tests/test_offload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_offload.py b/tests/test_offload.py index fa802f3..51569d9 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -6,7 +6,7 @@ def test_compevent_offload(): data = load_data("SEQdata_LTFU") options = SEQopts( - bootstrap_nboot = 20, + bootstrap_nboot = 2, cense_colname = "LTFU", excused = True, excused_colnames = ["excusedZero", "excusedOne"], From 904b354f7e237da465e04fff88883111bbcdcce1 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:29:28 -0800 Subject: [PATCH 18/27] skipping compevent tests --- tests/test_survival.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_survival.py b/tests/test_survival.py index 0e4ddfe..449429d 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,6 +1,9 @@ from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data +import os +import pytest + def test_regular_survival(): data = load_data("SEQdata") @@ -87,7 +90,9 @@ def test_subgroup_bootstrapped_survival(): s.survival() return - +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Compevent dying in CI environment" +) def test_compevent(): data = load_data("SEQdata_LTFU") @@ -110,7 +115,9 @@ def test_compevent(): s.survival() return - +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Compevent dying in CI environment" +) def test_bootstrapped_compevent(): data = load_data("SEQdata_LTFU") @@ -137,7 +144,9 @@ def test_bootstrapped_compevent(): s.survival() return - +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Compevent dying in CI environment" +) def test_subgroup_compevent(): data = load_data("SEQdata_LTFU") From 9045ad7353ca4a13b2734ed2ef11a8cffd735df9 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Tue, 30 Dec 2025 09:39:03 -0800 Subject: [PATCH 19/27] formatted --- pySEQTarget/SEQopts.py | 5 +-- pySEQTarget/SEQuential.py | 16 ++++---- pySEQTarget/analysis/_survival_pred.py | 2 +- pySEQTarget/helpers/__init__.py | 2 +- pySEQTarget/helpers/_bootstrap.py | 8 ++-- pySEQTarget/helpers/_offloader.py | 43 ++++++++++---------- pySEQTarget/weighting/__init__.py | 2 +- pySEQTarget/weighting/_weight_bind.py | 4 +- pySEQTarget/weighting/_weight_offload.py | 16 ++++---- pySEQTarget/weighting/_weight_pred.py | 3 +- tests/test_offload.py | 51 ++++++++++++------------ tests/test_survival.py | 10 +++-- 12 files changed, 82 insertions(+), 80 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 9de9cea..ad3b92b 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -1,7 +1,7 @@ import multiprocessing +import os from dataclasses import dataclass, field from typing import List, Literal, Optional -import os @dataclass @@ -204,7 +204,6 @@ def __post_init__(self): attr = getattr(self, i) if attr is not None and not isinstance(attr, list): setattr(self, i, "".join(attr.split())) - + if self.offload: os.makedirs(self.offload_dir, exist_ok=True) - diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 0408a58..8d8602d 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -12,15 +12,15 @@ _subgroup_fit) from .error import _data_checker, _param_checker from .expansion import _binder, _diagnostics, _dynamic, _random_selection -from .helpers import _col_string, _format_time, bootstrap_loop, Offloader +from .helpers import Offloader, _col_string, _format_time, bootstrap_loop from .initialization import (_cense_denominator, _cense_numerator, _denominator, _numerator, _outcome) from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, - _fit_visit, _weight_bind, _weight_predict, - _weight_setup, _weight_stats, _offload_weights) + _fit_visit, _offload_weights, _weight_bind, + _weight_predict, _weight_setup, _weight_stats) class SEQuential: @@ -83,10 +83,8 @@ def __init__( self._rng = ( np.random.RandomState(self.seed) if self.seed is not None else np.random ) - - self._offloader = Offloader( - enabled = self.offload, - dir = self.offload_dir) + + self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir) if self.covariates is None: self.covariates = _outcome(self) @@ -208,7 +206,7 @@ def fit(self) -> None: boot_idx = None if hasattr(self, "_current_boot_idx"): boot_idx = self._current_boot_idx - + if self.weighted: WDT = _weight_setup(self) if not self.weight_preexpansion and not self.excused: @@ -223,7 +221,7 @@ def fit(self) -> None: _fit_visit(self, WDT) _fit_numerator(self, WDT) _fit_denominator(self, WDT) - + if self.offload: _offload_weights(self, boot_idx) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index 8b98b6b..79e99d2 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -11,7 +11,7 @@ def _get_outcome_predictions(self, TxDT, idx=None): model_dict = boot_model[idx] if idx is not None else boot_model outcome_model = self._offloader.load_model(model_dict["outcome"]) predictions["outcome"].append(outcome_model.predict(data)) - + if self.compevent_colname is not None: compevent_model = self._offloader.load_model(model_dict["compevent"]) predictions["compevent"].append(compevent_model.predict(data)) diff --git a/pySEQTarget/helpers/__init__.py b/pySEQTarget/helpers/__init__.py index 93f1b69..860e686 100644 --- a/pySEQTarget/helpers/__init__.py +++ b/pySEQTarget/helpers/__init__.py @@ -1,9 +1,9 @@ from ._bootstrap import bootstrap_loop as bootstrap_loop from ._col_string import _col_string as _col_string from ._format_time import _format_time as _format_time +from ._offloader import Offloader as Offloader from ._output_files import _build_md as _build_md from ._output_files import _build_pdf as _build_pdf from ._pad import _pad as _pad from ._predict_model import _predict_model as _predict_model from ._prepare_data import _prepare_data as _prepare_data -from ._offloader import Offloader as Offloader diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 2f37596..69770b2 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -62,7 +62,7 @@ def wrapper(self, *args, **kwargs): results = [] original_DT = self.DT - + self._current_boot_idx = None full = method(self, *args, **kwargs) results.append(full) @@ -74,7 +74,7 @@ def wrapper(self, *args, **kwargs): ncores = self.ncores seed = getattr(self, "seed", None) method_name = method.__name__ - + original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT") if getattr(self, "parallel", False): @@ -114,8 +114,8 @@ def wrapper(self, *args, **kwargs): self.bootstrap_nboot = 0 boot_fit = method(self, *args, **kwargs) results.append(boot_fit) - - self.bootstrap_nboot = nboot + + self.bootstrap_nboot = nboot self.DT = self._offloader.load_dataframe(original_DT_ref) end = time.perf_counter() diff --git a/pySEQTarget/helpers/_offloader.py b/pySEQTarget/helpers/_offloader.py index 7b02983..d0882d0 100644 --- a/pySEQTarget/helpers/_offloader.py +++ b/pySEQTarget/helpers/_offloader.py @@ -1,52 +1,53 @@ from pathlib import Path from typing import Any, Optional, Union + import joblib import polars as pl class Offloader: """Manages disk-based storage for models and intermediate data""" - - def __init__(self, - enabled: bool, - dir: str, - compression: int = 3 - ): + + def __init__(self, enabled: bool, dir: str, compression: int = 3): self.enabled = enabled self.dir = Path(dir) self.compression = compression - - def save_model(self, model: Any, name: str, boot_idx: Optional[int] = None) -> Union[Any, str]: + + def save_model( + self, model: Any, name: str, boot_idx: Optional[int] = None + ) -> Union[Any, str]: """Save a fitted model to disk and return a reference""" if not self.enabled: return model - - filename = f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl" + + filename = ( + f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl" + ) filepath = self.dir / filename - + joblib.dump(model, filepath, compress=self.compression) - + return str(filepath) - + def load_model(self, ref: Union[Any, str]) -> Any: if not self.enabled or not isinstance(ref, str): return ref - + return joblib.load(ref) - + def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]: if not self.enabled: return df - + filename = f"{name}.parquet" filepath = self.dir / filename - + df.write_parquet(filepath, compression="zstd") - + return str(filepath) - + def load_dataframe(self, ref: Union[pl.DataFrame, str]) -> pl.DataFrame: if not self.enabled or not isinstance(ref, str): return ref - - return pl.read_parquet(ref) \ No newline at end of file + + return pl.read_parquet(ref) diff --git a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py index d037d91..7c6e32d 100644 --- a/pySEQTarget/weighting/__init__.py +++ b/pySEQTarget/weighting/__init__.py @@ -4,6 +4,6 @@ from ._weight_fit import _fit_LTFU as _fit_LTFU from ._weight_fit import _fit_numerator as _fit_numerator from ._weight_fit import _fit_visit as _fit_visit +from ._weight_offload import _offload_weights as _offload_weights from ._weight_pred import _weight_predict as _weight_predict from ._weight_stats import _weight_stats as _weight_stats -from ._weight_offload import _offload_weights as _offload_weights diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index 41bba44..307c426 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -7,9 +7,7 @@ def _weight_bind(self, WDT): on = [self.id_col, "period"] WDT = WDT.rename({self.time_col: "period"}) self.DT = self.DT.with_columns( - pl.col(self.id_col) - .str.replace(r"_\d+$", "") - .alias(self.id_col) + pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col) ) else: join = "left" diff --git a/pySEQTarget/weighting/_weight_offload.py b/pySEQTarget/weighting/_weight_offload.py index db90fb5..04603e8 100644 --- a/pySEQTarget/weighting/_weight_offload.py +++ b/pySEQTarget/weighting/_weight_offload.py @@ -1,17 +1,19 @@ def _offload_weights(self, boot_idx): """Helper to offload weight models to disk""" weight_models = [ - ('numerator_model', 'numerator'), - ('denominator_model', 'denominator'), - ('LTFU_model', 'LTFU'), - ('visit_model', 'visit'), + ("numerator_model", "numerator"), + ("denominator_model", "denominator"), + ("LTFU_model", "LTFU"), + ("visit_model", "visit"), ] - + for model_attr, model_name in weight_models: if hasattr(self, model_attr): model_list = getattr(self, model_attr) if model_list and isinstance(model_list, list) and len(model_list) > 0: latest_model = model_list[-1] if latest_model is not None: - offloaded = self._offloader.save_model(latest_model, model_name, boot_idx) - model_list[-1] = offloaded \ No newline at end of file + offloaded = self._offloader.save_model( + latest_model, model_name, boot_idx + ) + model_list[-1] = offloaded diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index a3a611f..a751c6d 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -22,11 +22,10 @@ def _weight_predict(self, WDT): for i, level in enumerate(self.treatment_level): mask = pl.col("tx_lag") == level lag_mask = (WDT["tx_lag"] == level).to_numpy() - + denom_model = self._offloader.load_model(self.denominator_model[i]) num_model = self._offloader.load_model(self.numerator_model[i]) - if denom_model is not None: pred_denom = np.ones(WDT.height) if lag_mask.sum() > 0: diff --git a/tests/test_offload.py b/tests/test_offload.py index 51569d9..6c46daf 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -1,40 +1,41 @@ +import warnings + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data -import warnings def test_compevent_offload(): data = load_data("SEQdata_LTFU") options = SEQopts( - bootstrap_nboot = 2, - cense_colname = "LTFU", - excused = True, - excused_colnames = ["excusedZero", "excusedOne"], - km_curves = True, - selection_random = True, - selection_sample = 0.30, - weighted = True, + bootstrap_nboot=2, + cense_colname="LTFU", + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + km_curves=True, + selection_random=True, + selection_sample=0.30, + weighted=True, weight_lag_condition=False, - weight_p99 = True, - weight_preexpansion = True, - offload = True + weight_p99=True, + weight_preexpansion=True, + offload=True, + ) + + model = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=options, ) - - model = SEQuential(data, - id_col="ID", - time_col="time", - eligible_col="eligible", - treatment_col="tx_init", - outcome_col="outcome", - time_varying_cols=["N", "L", "P"], - fixed_cols=["sex"], - method="censoring", - parameters=options) model.expand() model.bootstrap() # Warnings from statsmodels about overflow in some bootstraps warnings.filterwarnings("ignore") model.fit() model.survival() - - \ No newline at end of file diff --git a/tests/test_survival.py b/tests/test_survival.py index 449429d..36f4261 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,9 +1,10 @@ -from pySEQTarget import SEQopts, SEQuential -from pySEQTarget.data import load_data - import os + import pytest +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_regular_survival(): data = load_data("SEQdata") @@ -90,6 +91,7 @@ def test_subgroup_bootstrapped_survival(): s.survival() return + @pytest.mark.skipif( os.getenv("CI") == "true", reason="Compevent dying in CI environment" ) @@ -115,6 +117,7 @@ def test_compevent(): s.survival() return + @pytest.mark.skipif( os.getenv("CI") == "true", reason="Compevent dying in CI environment" ) @@ -144,6 +147,7 @@ def test_bootstrapped_compevent(): s.survival() return + @pytest.mark.skipif( os.getenv("CI") == "true", reason="Compevent dying in CI environment" ) From a27a358b462bf9b3460317cbb2f1b150292f3ee1 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 1 Jan 2026 14:19:06 +0000 Subject: [PATCH 20/27] Use smf.logit() for binary treatment vars When: * treatment_level=[0, 1] (binary control vs treatment) * method="censoring" * weighted=True --- pySEQTarget/weighting/_weight_fit.py | 20 ++++++++++++-- pySEQTarget/weighting/_weight_pred.py | 40 +++++++++++++++++++++------ 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index ec3bf2a..0395ca4 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -54,6 +54,9 @@ def _fit_numerator(self, WDT): f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" ) fits = [] + # Use logit for binary 0/1 treatment with censoring method only + # treatment_level=[1,2] or dose-response always uses mnlogit + is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" for i, level in enumerate(self.treatment_level): if self.excused and self.excused_colnames[i] is not None: DT_subset = WDT[WDT[self.excused_colnames[i]] == 0] @@ -63,11 +66,16 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - model = smf.mnlogit(formula, DT_subset) + # Use logit for binary 0/1 censoring, mnlogit otherwise + if is_binary: + model = smf.logit(formula, DT_subset) + else: + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) self.numerator_model = fits + self._is_binary_treatment = is_binary def _fit_denominator(self, WDT): @@ -80,6 +88,9 @@ def _fit_denominator(self, WDT): ) formula = f"{predictor}~{self.denominator}" fits = [] + # Use logit for binary 0/1 treatment with censoring method only + # treatment_level=[1,2] or dose-response always uses mnlogit + is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" for i, level in enumerate(self.treatment_level): if self.excused and self.excused_colnames[i] is not None: DT_subset = WDT[WDT[self.excused_colnames[i]] == 0] @@ -92,8 +103,13 @@ def _fit_denominator(self, WDT): if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - model = smf.mnlogit(formula, DT_subset) + # Use logit for binary 0/1 censoring, mnlogit otherwise + if is_binary: + model = smf.logit(formula, DT_subset) + else: + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) self.denominator_model = fits + self._is_binary_treatment = is_binary diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index a751c6d..7292aa1 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -9,6 +9,11 @@ def _weight_predict(self, WDT): grouping += ["trial"] if not self.weight_preexpansion else [] time = self.time_col if self.weight_preexpansion else "followup" + # Check if binary 0/1 treatment with censoring (set during fitting) + # Must match the logic in _weight_fit.py + is_binary = getattr(self, '_is_binary_treatment', + sorted(self.treatment_level) == [0, 1] and self.method == "censoring") + if self.method == "ITT": WDT = WDT.with_columns( [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] @@ -23,6 +28,7 @@ def _weight_predict(self, WDT): mask = pl.col("tx_lag") == level lag_mask = (WDT["tx_lag"] == level).to_numpy() + # Load models via offloader (handles both offloaded and in-memory models) denom_model = self._offloader.load_model(self.denominator_model[i]) num_model = self._offloader.load_model(self.numerator_model[i]) @@ -31,13 +37,23 @@ def _weight_predict(self, WDT): if lag_mask.sum() > 0: subset = WDT.filter(pl.Series(lag_mask)) p = _predict_model(self, denom_model, subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] + + # Handle binary vs multinomial prediction output + if is_binary: + # logit returns P(Y=1) directly as 1D array + # For i=0 (level 0): want P(stay at 0) = 1 - P(Y=1) + # For i=1 (level 1): want P(switch to 1) = P(Y=1) + p_class = p if i == 1 else (1 - p) + else: + # mnlogit returns [P(Y=0), P(Y=1), ...] as 2D array + if p.ndim == 1: + p = p.reshape(-1, 1) + p_class = p[:, i] + switched_treatment = ( subset[self.treatment_col] != subset["tx_lag"] ).to_numpy() - pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p_class, p_class) else: pred_denom = np.ones(WDT.height) @@ -46,13 +62,21 @@ def _weight_predict(self, WDT): if lag_mask.sum() > 0: subset = WDT.filter(pl.Series(lag_mask)) p = _predict_model(self, num_model, subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] + + # Handle binary vs multinomial prediction output + if is_binary: + # logit returns P(Y=1) directly as 1D array + p_class = p if i == 1 else (1 - p) + else: + # mnlogit returns [P(Y=0), P(Y=1), ...] as 2D array + if p.ndim == 1: + p = p.reshape(-1, 1) + p_class = p[:, i] + switched_treatment = ( subset[self.treatment_col] != subset["tx_lag"] ).to_numpy() - pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p_class, p_class) else: pred_num = np.ones(WDT.height) From 96fb1f548466ddf6cef89c2be4d0b050c015bf0c Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 1 Jan 2026 18:15:42 +0000 Subject: [PATCH 21/27] Handle intercept-only formula when numerator is "1" or empty --- pySEQTarget/SEQopts.py | 4 ++-- pySEQTarget/weighting/_weight_fit.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index ad3b92b..6a5061e 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -19,7 +19,7 @@ class SEQopts: :type bootstrap_CI_method: str :param cense_colname: Column name for censoring effect (LTFU, etc.) :type cense_colname: str - :param cense_denominator: Override to specify denominator patsy formula for censoring models + :param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model :type cense_denominator: Optional[str] or None :param cense_numerator: Override to specify numerator patsy formula for censoring models :type cense_numerator: Optional[str] or None @@ -55,7 +55,7 @@ class SEQopts: :type km_curves: bool :param ncores: Number of cores to use if running in parallel :type ncores: int - :param numerator: Override to specify the outcome patsy formula for numerator models + :param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model :type numerator: str :param offload: Boolean to offload intermediate model data to disk :type offload: bool diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index 0395ca4..dada57d 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -49,7 +49,11 @@ def _fit_numerator(self, WDT): if self.method == "ITT": return predictor = "switch" if self.excused else self.treatment_col - formula = f"{predictor}~{self.numerator}" + # Handle intercept-only formula when numerator is "1" or empty + if self.numerator in ("1", ""): + formula = f"{predictor}~1" + else: + formula = f"{predictor}~{self.numerator}" tx_bas = ( f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" ) @@ -86,7 +90,11 @@ def _fit_denominator(self, WDT): if self.excused and not self.weight_preexpansion else self.treatment_col ) - formula = f"{predictor}~{self.denominator}" + # Handle intercept-only formula when denominator is "1" or empty + if self.denominator in ("1", ""): + formula = f"{predictor}~1" + else: + formula = f"{predictor}~{self.denominator}" fits = [] # Use logit for binary 0/1 treatment with censoring method only # treatment_level=[1,2] or dose-response always uses mnlogit From 0315a60539079608dbfcbe2c76a472a250878a79 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 1 Jan 2026 18:39:05 +0000 Subject: [PATCH 22/27] Allow specifying fitting method --- pySEQTarget/SEQopts.py | 3 +++ pySEQTarget/weighting/_weight_fit.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 6a5061e..4293a99 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -89,6 +89,8 @@ class SEQopts: :type visit_colname: str :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting :type weight_eligible_colnames: List[str] + :param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton" + :type weight_fit_method: str :param weight_min: Minimum weight :type weight_min: float :param weight_max: Maximum weight @@ -145,6 +147,7 @@ class SEQopts: trial_include: bool = True visit_colname: str = None weight_eligible_colnames: List[str] = field(default_factory=lambda: []) + weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton" weight_min: float = 0.0 weight_max: float = None weight_lag_condition: bool = True diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index dada57d..ead8814 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -15,7 +15,7 @@ def _fit_pair( for rhs, out in zip(formula_attr, output_attrs): formula = f"{outcome}~{rhs}" model = smf.glm(formula, WDT, family=sm.families.Binomial()) - setattr(self, out, model.fit(disp=0)) + setattr(self, out, model.fit(disp=0, method=self.weight_fit_method)) def _fit_LTFU(self, WDT): @@ -75,7 +75,7 @@ def _fit_numerator(self, WDT): model = smf.logit(formula, DT_subset) else: model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0) + model_fit = model.fit(disp=0, method=self.weight_fit_method) fits.append(model_fit) self.numerator_model = fits @@ -116,7 +116,7 @@ def _fit_denominator(self, WDT): model = smf.logit(formula, DT_subset) else: model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0) + model_fit = model.fit(disp=0, method=self.weight_fit_method) fits.append(model_fit) self.denominator_model = fits From 90dca1419f3ee94fe16276596b940b6b06cb4d13 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 1 Jan 2026 19:02:37 +0000 Subject: [PATCH 23/27] Obtain expected category levels from fitted model --- pySEQTarget/helpers/_predict_model.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 5ddd731..66828f6 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -3,7 +3,27 @@ def _predict_model(self, model, newdata): newdata = newdata.to_pandas() + + # Original behavior - convert fixed_cols to category for col in self.fixed_cols: if col in newdata.columns: newdata[col] = newdata[col].astype("category") - return np.array(model.predict(newdata)) + + try: + return np.array(model.predict(newdata)) + except Exception as e: + if "mismatching levels" in str(e): + # Fix category levels from model's design_info + if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'): + design_info = model.model.data.design_info + for factor, factor_info in design_info.factor_infos.items(): + if factor_info.type == 'categorical': + col_name = factor.name() + if col_name in newdata.columns: + expected_categories = list(factor_info.categories) + newdata[col_name] = newdata[col_name].astype(str) + newdata[col_name] = newdata[col_name].astype('category') + newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories) + return np.array(model.predict(newdata)) + else: + raise From 360dd6a45d88150a98c150a9991efd08ba7cc7da Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 1 Jan 2026 19:30:09 +0000 Subject: [PATCH 24/27] Improve handling of categories for predictions --- pySEQTarget/analysis/_hazard.py | 19 +++++++++++++++++-- pySEQTarget/helpers/_fix_categories.py | 15 +++++++++++++++ pySEQTarget/helpers/_predict_model.py | 14 +++----------- 3 files changed, 35 insertions(+), 13 deletions(-) create mode 100644 pySEQTarget/helpers/_fix_categories.py diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 27d32f4..c26c747 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -4,6 +4,8 @@ import polars as pl from lifelines import CoxPHFitter +from ..helpers._fix_categories import _fix_categories_for_predict + def _calculate_hazard(self): if self.subgroup_colname is None: @@ -64,6 +66,18 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(full_hr, lci, uci, val, self) +def _safe_predict(model, data_pd): + """Predict with category fix fallback if needed.""" + try: + return model.predict(data_pd) + except Exception as e: + if "mismatching levels" in str(e): + data_pd = _fix_categories_for_predict(model, data_pd) + return model.predict(data_pd) + else: + raise + + def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ "followup", @@ -105,13 +119,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng): ) tmp_pd = tmp.to_pandas() - outcome_prob = outcome_model.predict(tmp_pd) + outcome_prob = _safe_predict(outcome_model, tmp_pd) outcome_sim = rng.binomial(1, outcome_prob) tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)]) if ce_model is not None: - ce_prob = ce_model.predict(tmp_pd) + ce_tmp_pd = tmp.to_pandas() + ce_prob = _safe_predict(ce_model, ce_tmp_pd) ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) diff --git a/pySEQTarget/helpers/_fix_categories.py b/pySEQTarget/helpers/_fix_categories.py new file mode 100644 index 0000000..dda7f98 --- /dev/null +++ b/pySEQTarget/helpers/_fix_categories.py @@ -0,0 +1,15 @@ +def _fix_categories_for_predict(model, newdata): + """ + Fix categorical column ordering in newdata to match what the model expects. + """ + if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'): + design_info = model.model.data.design_info + for factor, factor_info in design_info.factor_infos.items(): + if factor_info.type == 'categorical': + col_name = factor.name() + if col_name in newdata.columns: + expected_categories = list(factor_info.categories) + newdata[col_name] = newdata[col_name].astype(str) + newdata[col_name] = newdata[col_name].astype('category') + newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories) + return newdata diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 66828f6..c254061 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -1,5 +1,7 @@ import numpy as np +from ._fix_categories import _fix_categories_for_predict + def _predict_model(self, model, newdata): newdata = newdata.to_pandas() @@ -13,17 +15,7 @@ def _predict_model(self, model, newdata): return np.array(model.predict(newdata)) except Exception as e: if "mismatching levels" in str(e): - # Fix category levels from model's design_info - if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'): - design_info = model.model.data.design_info - for factor, factor_info in design_info.factor_infos.items(): - if factor_info.type == 'categorical': - col_name = factor.name() - if col_name in newdata.columns: - expected_categories = list(factor_info.categories) - newdata[col_name] = newdata[col_name].astype(str) - newdata[col_name] = newdata[col_name].astype('category') - newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories) + newdata = _fix_categories_for_predict(model, newdata) return np.array(model.predict(newdata)) else: raise From 8d3aba0a1ae70cd37ae632fdc1f49d41632a833e Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 2 Jan 2026 07:12:00 +0000 Subject: [PATCH 25/27] Account for NaNs in predicted probs --- pySEQTarget/analysis/_hazard.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index c26c747..d132aee 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -68,15 +68,26 @@ def _calculate_hazard_single(self, data, idx=None, val=None): def _safe_predict(model, data_pd): """Predict with category fix fallback if needed.""" + # Make a copy to avoid modifying original + data_pd = data_pd.copy() + try: - return model.predict(data_pd) + probs = model.predict(data_pd) except Exception as e: if "mismatching levels" in str(e): data_pd = _fix_categories_for_predict(model, data_pd) - return model.predict(data_pd) + probs = model.predict(data_pd) else: raise + # Ensure probabilities are valid (clip to [0, 1] and replace NaN with 0.5) + probs = np.array(probs) + if np.any(np.isnan(probs)): + warnings.warn("NaN values in predicted probabilities, replacing with 0.5") + probs = np.where(np.isnan(probs), 0.5, probs) + probs = np.clip(probs, 0, 1) + + return probs def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ From 27ffd787897a4b42a3392f387fba550620eed8ec Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 2 Jan 2026 12:33:40 +0000 Subject: [PATCH 26/27] Make survival preds safe --- pySEQTarget/analysis/_survival_pred.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index 79e99d2..e649801 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -1,5 +1,19 @@ import polars as pl +from ..helpers._fix_categories import _fix_categories_for_predict + + +def _safe_predict(model, data): + """Predict with category fix fallback if needed.""" + try: + return model.predict(data) + except Exception as e: + if "mismatching levels" in str(e): + data = _fix_categories_for_predict(model, data) + return model.predict(data) + else: + raise + def _get_outcome_predictions(self, TxDT, idx=None): data = TxDT.to_pandas() @@ -10,11 +24,11 @@ def _get_outcome_predictions(self, TxDT, idx=None): for boot_model in self.outcome_model: model_dict = boot_model[idx] if idx is not None else boot_model outcome_model = self._offloader.load_model(model_dict["outcome"]) - predictions["outcome"].append(outcome_model.predict(data)) + predictions["outcome"].append(_safe_predict(outcome_model, data.copy())) if self.compevent_colname is not None: compevent_model = self._offloader.load_model(model_dict["compevent"]) - predictions["compevent"].append(compevent_model.predict(data)) + predictions["compevent"].append(_safe_predict(compevent_model, data.copy())) return predictions From c211376baab78b8bde48554a906137ef685b3562 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 2 Jan 2026 12:39:55 +0000 Subject: [PATCH 27/27] Move _safe_predict into a helper file --- pySEQTarget/analysis/_hazard.py | 25 +----------------- pySEQTarget/analysis/_survival_pred.py | 14 +--------- pySEQTarget/helpers/_predict_model.py | 36 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index d132aee..06200ba 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -4,7 +4,7 @@ import polars as pl from lifelines import CoxPHFitter -from ..helpers._fix_categories import _fix_categories_for_predict +from ..helpers._predict_model import _safe_predict def _calculate_hazard(self): @@ -66,29 +66,6 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(full_hr, lci, uci, val, self) -def _safe_predict(model, data_pd): - """Predict with category fix fallback if needed.""" - # Make a copy to avoid modifying original - data_pd = data_pd.copy() - - try: - probs = model.predict(data_pd) - except Exception as e: - if "mismatching levels" in str(e): - data_pd = _fix_categories_for_predict(model, data_pd) - probs = model.predict(data_pd) - else: - raise - - # Ensure probabilities are valid (clip to [0, 1] and replace NaN with 0.5) - probs = np.array(probs) - if np.any(np.isnan(probs)): - warnings.warn("NaN values in predicted probabilities, replacing with 0.5") - probs = np.where(np.isnan(probs), 0.5, probs) - probs = np.clip(probs, 0, 1) - - return probs - def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ "followup", diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index e649801..4b24c4f 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -1,18 +1,6 @@ import polars as pl -from ..helpers._fix_categories import _fix_categories_for_predict - - -def _safe_predict(model, data): - """Predict with category fix fallback if needed.""" - try: - return model.predict(data) - except Exception as e: - if "mismatching levels" in str(e): - data = _fix_categories_for_predict(model, data) - return model.predict(data) - else: - raise +from ..helpers._predict_model import _safe_predict def _get_outcome_predictions(self, TxDT, idx=None): diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index c254061..ec9fadc 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -1,8 +1,44 @@ +import warnings + import numpy as np from ._fix_categories import _fix_categories_for_predict +def _safe_predict(model, data, clip_probs=True): + """ + Predict with category fix fallback if needed. + + Parameters + ---------- + model : statsmodels model + Fitted model object + data : pandas DataFrame + Data to predict on + clip_probs : bool + If True, clip probabilities to [0, 1] and replace NaN with 0.5 + """ + data = data.copy() + + try: + probs = model.predict(data) + except Exception as e: + if "mismatching levels" in str(e): + data = _fix_categories_for_predict(model, data) + probs = model.predict(data) + else: + raise + + if clip_probs: + probs = np.array(probs) + if np.any(np.isnan(probs)): + warnings.warn("NaN values in predicted probabilities, replacing with 0.5") + probs = np.where(np.isnan(probs), 0.5, probs) + probs = np.clip(probs, 0, 1) + + return probs + + def _predict_model(self, model, newdata): newdata = newdata.to_pandas()