diff --git a/docs/conf.py b/docs/conf.py
index 1023c2c..ec9b34e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -12,7 +12,7 @@
 
 version = importlib.metadata.version("pySEQTarget")
 if not version:
-    version = "0.11.0"
+    version = "0.12.0"
 
 sys.path.insert(0, os.path.abspath("../"))
 project = "pySEQTarget"
diff --git a/pySEQTarget/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py
index 39e2382..0cf330b 100644
--- a/pySEQTarget/analysis/_risk_estimates.py
+++ b/pySEQTarget/analysis/_risk_estimates.py
@@ -2,6 +2,69 @@
 from scipy import stats
 
 
+def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
+    """
+    Compute Risk Difference and Risk Ratio from a comparison dataframe.
+    Consolidates the repeated calculation logic.
+    """
+    if group_cols is None:
+        group_cols = []
+
+    if has_bootstrap:
+        rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
+        rd_comp = comp.with_columns(
+            [
+                (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
+                (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias("RD 95% LCI"),
+                (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias("RD 95% UCI"),
+            ]
+        )
+        rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
+        col_order = group_cols + [
+            "A_x",
+            "A_y",
+            "Risk Difference",
+            "RD 95% LCI",
+            "RD 95% UCI",
+        ]
+        rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
+
+        rr_log_se = (
+            (pl.col("se_x") / pl.col("risk_x")).pow(2)
+            + (pl.col("se_y") / pl.col("risk_y")).pow(2)
+        ).sqrt()
+        rr_comp = comp.with_columns(
+            [
+                (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
+                (
+                    (pl.col("risk_x") / pl.col("risk_y")) * (-z * rr_log_se).exp()
+                ).alias("RR 95% LCI"),
+                (
+                    (pl.col("risk_x") / pl.col("risk_y")) * (z * rr_log_se).exp()
+                ).alias("RR 95% UCI"),
+            ]
+        )
+        rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
+        col_order = group_cols + ["A_x", "A_y", "Risk Ratio", "RR 95% LCI", "RR 95% UCI"]
+        rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
+    else:
+        rd_comp = comp.with_columns(
+            (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
+        )
+        rd_comp = rd_comp.drop(["risk_x", "risk_y"])
+        col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
+        rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
+
+        rr_comp = comp.with_columns(
+            (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
+        )
+        rr_comp = rr_comp.drop(["risk_x", "risk_y"])
+        col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
+        rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
+
+    return rd_comp, rr_comp
+
+
 def _risk_estimates(self):
     last_followup = self.km_data["followup"].max()
     risk = self.km_data.filter(
@@ -9,29 +72,35 @@ def _risk_estimates(self):
     )
     group_cols = [self.subgroup_colname] if self.subgroup_colname else []
 
-    rd_comparisons = []
-    rr_comparisons = []
+    has_bootstrap = self.bootstrap_nboot > 0
 
-    if self.bootstrap_nboot > 0:
+    if has_bootstrap:
         alpha = 1 - self.bootstrap_CI
         z = stats.norm.ppf(1 - alpha / 2)
+    else:
+        z = None
+
+    # Pre-extract data for each treatment level once (avoid repeated filtering)
+    risk_by_level = {}
+    for tx in self.treatment_level:
+        level_data = risk.filter(pl.col(self.treatment_col) == tx)
+        risk_by_level[tx] = {
+            "pred": level_data.select(group_cols + ["pred"]),
+        }
+        if has_bootstrap:
+            risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])
+
+    rd_comparisons = []
+    rr_comparisons = []
 
     for tx_x in self.treatment_level:
         for tx_y in self.treatment_level:
             if tx_x == tx_y:
                 continue
-            risk_x = (
-                risk.filter(pl.col(self.treatment_col) == tx_x)
-                .select(group_cols + ["pred"])
-                .rename({"pred": "risk_x"})
-            )
-
-            risk_y = (
-                risk.filter(pl.col(self.treatment_col) == tx_y)
-                .select(group_cols + ["pred"])
-                .rename({"pred": "risk_y"})
-            )
+            # Use pre-extracted data instead of filtering again
+            risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
+            risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})
 
             if group_cols:
                 comp = risk_x.join(risk_y, on=group_cols, how="left")
@@ -42,18 +111,9 @@ def _risk_estimates(self):
                 [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
             )
 
-            if self.bootstrap_nboot > 0:
-                se_x = (
-                    risk.filter(pl.col(self.treatment_col) == tx_x)
-                    .select(group_cols + ["SE"])
-                    .rename({"SE": "se_x"})
-                )
-
-                se_y = (
-                    risk.filter(pl.col(self.treatment_col) == tx_y)
-                    .select(group_cols + ["SE"])
-                    .rename({"SE": "se_y"})
-                )
+            if has_bootstrap:
+                se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
+                se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
 
                 if group_cols:
                     comp = comp.join(se_x, on=group_cols, how="left")
@@ -62,73 +122,9 @@ def _risk_estimates(self):
                     comp = comp.join(se_x, how="cross")
                     comp = comp.join(se_y, how="cross")
 
-                rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
-                rd_comp = comp.with_columns(
-                    [
-                        (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
-                        (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias(
-                            "RD 95% LCI"
-                        ),
-                        (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias(
-                            "RD 95% UCI"
-                        ),
-                    ]
-                )
-                rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
-                col_order = group_cols + [
-                    "A_x",
-                    "A_y",
-                    "Risk Difference",
-                    "RD 95% LCI",
-                    "RD 95% UCI",
-                ]
-                rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
-                rd_comparisons.append(rd_comp)
-
-                rr_log_se = (
-                    (pl.col("se_x") / pl.col("risk_x")).pow(2)
-                    + (pl.col("se_y") / pl.col("risk_y")).pow(2)
-                ).sqrt()
-                rr_comp = comp.with_columns(
-                    [
-                        (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
-                        (
-                            (pl.col("risk_x") / pl.col("risk_y"))
-                            * (-z * rr_log_se).exp()
-                        ).alias("RR 95% LCI"),
-                        (
-                            (pl.col("risk_x") / pl.col("risk_y"))
-                            * (z * rr_log_se).exp()
-                        ).alias("RR 95% UCI"),
-                    ]
-                )
-                rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
-                col_order = group_cols + [
-                    "A_x",
-                    "A_y",
-                    "Risk Ratio",
-                    "RR 95% LCI",
-                    "RR 95% UCI",
-                ]
-                rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
-                rr_comparisons.append(rr_comp)
-
-            else:
-                rd_comp = comp.with_columns(
-                    (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
-                )
-                rd_comp = rd_comp.drop(["risk_x", "risk_y"])
-                col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
-                rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
-                rd_comparisons.append(rd_comp)
-
-                rr_comp = comp.with_columns(
-                    (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
-                )
-                rr_comp = rr_comp.drop(["risk_x", "risk_y"])
-                col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
-                rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
-                rr_comparisons.append(rr_comp)
+            rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
+            rd_comparisons.append(rd_comp)
+            rr_comparisons.append(rr_comp)
 
     risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
     risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()
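The intervals consolidated into `_compute_rd_rr` are standard delta-method CIs: the RD interval is symmetric on the risk scale with SEs added in quadrature, while the RR interval is built on the log scale and exponentiated. A minimal standalone sketch of the same arithmetic, with made-up numbers (`risk_x`, `se_x`, etc. are illustrative, not taken from the package):

```python
import numpy as np
from scipy import stats

risk_x, se_x = 0.30, 0.02  # illustrative risk and bootstrap SE, arm x
risk_y, se_y = 0.25, 0.03  # illustrative risk and bootstrap SE, arm y
z = stats.norm.ppf(1 - 0.05 / 2)  # 95% CI -> z ~ 1.96

# Risk Difference: symmetric interval, SEs combine in quadrature
rd = risk_x - risk_y
rd_se = np.sqrt(se_x**2 + se_y**2)
print("RD:", rd, (rd - z * rd_se, rd + z * rd_se))

# Risk Ratio: interval built on log(RR), then exponentiated
rr = risk_x / risk_y
rr_log_se = np.sqrt((se_x / risk_x) ** 2 + (se_y / risk_y) ** 2)  # SE of log(RR)
print("RR:", rr, (rr * np.exp(-z * rr_log_se), rr * np.exp(z * rr_log_se)))
```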
diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py
index 4b24c4f..4e9234f 100644
--- a/pySEQTarget/analysis/_survival_pred.py
+++ b/pySEQTarget/analysis/_survival_pred.py
@@ -46,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
     lci = a / 2
     uci = 1 - lci
 
+    # Pre-compute the followup range once (starts at 1, not 0)
+    followup_range = list(range(1, self.followup_max + 1))
+
     SDT = (
         data.with_columns(
-            [
-                (
-                    pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
-                ).alias("TID")
-            ]
+            [pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
         )
         .group_by("TID")
         .first()
         .drop(["followup", f"followup{self.indicator_squared}"])
-        .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
+        .with_columns([pl.lit(followup_range).alias("followup")])
         .explode("followup")
         .with_columns(
-            [
-                (pl.col("followup") + 1).alias("followup"),
-                (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
-            ]
+            [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
        )
    ).sort([self.id_col, "trial", "followup"])
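The rewritten pipeline builds the per-person follow-up grid by attaching a literal list and exploding it, starting at 1 so the later `+ 1` shift disappears and the squared term is computed from the final follow-up values. A toy sketch of the pattern, with a hypothetical `TID` column:

```python
import polars as pl

followup_max = 3
df = pl.DataFrame({"TID": ["a-0", "b-1"]})  # hypothetical person-trial ids

out = (
    df.with_columns(pl.lit(list(range(1, followup_max + 1))).alias("followup"))
    .explode("followup")  # one row per (TID, followup)
    .with_columns((pl.col("followup") ** 2).alias("followup_sq"))
)
print(out)  # followup runs 1..3 for each TID, and followup_sq matches it
```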
diff --git a/pySEQTarget/expansion/_mapper.py b/pySEQTarget/expansion/_mapper.py
index c169669..37600e5 100644
--- a/pySEQTarget/expansion/_mapper.py
+++ b/pySEQTarget/expansion/_mapper.py
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf):
         .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
         .with_columns(
             [
-                pl.struct(
-                    [
-                        pl.col(time_col),
-                        pl.col(time_col).max().over(id_col).alias("max_time"),
-                    ]
-                )
-                .map_elements(
-                    lambda x: list(range(x[time_col], x["max_time"] + 1)),
-                    return_dtype=pl.List(pl.Int64),
-                )
-                .alias("period")
+                pl.int_ranges(
+                    pl.col(time_col),
+                    pl.col(time_col).max().over(id_col) + 1,
+                ).alias("period")
             ]
         )
         .explode("period")
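`pl.int_ranges` builds the per-row integer lists inside the engine, so the per-row Python lambda from the replaced `pl.struct(...).map_elements(...)` version disappears. A quick check of the equivalence on toy data:

```python
import polars as pl

df = pl.DataFrame({"id": [1, 1, 2], "time": [0, 1, 5]})  # toy data

out = df.with_columns(
    # end is exclusive, hence the +1 to keep the per-id max inclusive
    pl.int_ranges(pl.col("time"), pl.col("time").max().over("id") + 1).alias("period")
).explode("period")
print(out)  # id 1 expands to periods [0, 1] and [1]; id 2 stays at [5]
```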
+ """ + DT_subset = WDT + + # Filter by excused column if applicable + if self.excused and self.excused_colnames[level_idx] is not None: + DT_subset = DT_subset[DT_subset[self.excused_colnames[level_idx]] == 0] + + # Filter by treatment lag condition + if self.weight_lag_condition: + DT_subset = DT_subset[DT_subset[tx_lag_col] == level] + + # Exclude followup == 0 for denominator (not pre-expansion) + if exclude_followup_zero: + DT_subset = DT_subset[DT_subset["followup"] != 0] + + # Filter by eligibility column if applicable + if self.weight_eligible_colnames[level_idx] is not None: + DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[level_idx]] == 1] + + return DT_subset + + def _fit_pair( self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None ): @@ -54,7 +80,7 @@ def _fit_numerator(self, WDT): formula = f"{predictor}~1" else: formula = f"{predictor}~{self.numerator}" - tx_bas = ( + tx_lag_col = ( f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" ) fits = [] @@ -62,14 +88,7 @@ def _fit_numerator(self, WDT): # treatment_level=[1,2] or dose-response always uses mnlogit is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" for i, level in enumerate(self.treatment_level): - if self.excused and self.excused_colnames[i] is not None: - DT_subset = WDT[WDT[self.excused_colnames[i]] == 0] - else: - DT_subset = WDT - if self.weight_lag_condition: - DT_subset = DT_subset[DT_subset[tx_bas] == level] - if self.weight_eligible_colnames[i] is not None: - DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] + DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col) # Use logit for binary 0/1 censoring, mnlogit otherwise if is_binary: model = smf.logit(formula, DT_subset) @@ -99,18 +118,11 @@ def _fit_denominator(self, WDT): # Use logit for binary 0/1 treatment with censoring method only # treatment_level=[1,2] or dose-response always uses mnlogit is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" + exclude_followup_zero = not self.weight_preexpansion for i, level in enumerate(self.treatment_level): - if self.excused and self.excused_colnames[i] is not None: - DT_subset = WDT[WDT[self.excused_colnames[i]] == 0] - else: - DT_subset = WDT - if self.weight_lag_condition: - DT_subset = DT_subset[DT_subset["tx_lag"] == level] - if not self.weight_preexpansion: - DT_subset = DT_subset[DT_subset["followup"] != 0] - if self.weight_eligible_colnames[i] is not None: - DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - + DT_subset = _get_subset_for_level( + self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero + ) # Use logit for binary 0/1 censoring, mnlogit otherwise if is_binary: model = smf.logit(formula, DT_subset) diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 354e6fc..551e3b4 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -4,6 +4,23 @@ from ..helpers import _predict_model +def _extract_class_probability(p, level_idx, is_binary): + """ + Extract the probability for a specific class from model predictions. + Handles both binary (logit) and multinomial (mnlogit) output formats. 
diff --git a/pySEQTarget/helpers/_offloader.py b/pySEQTarget/helpers/_offloader.py
index d0882d0..509a233 100644
--- a/pySEQTarget/helpers/_offloader.py
+++ b/pySEQTarget/helpers/_offloader.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from pathlib import Path
 from typing import Any, Optional, Union
 
@@ -12,6 +13,25 @@ def __init__(self, enabled: bool, dir: str, compression: int = 3):
         self.enabled = enabled
         self.dir = Path(dir)
         self.compression = compression
+        # Create a cached loader bound to this instance
+        self._init_cache()
+
+    def _init_cache(self):
+        """Initialize the LRU cache for model loading."""
+        self._cached_load = lru_cache(maxsize=32)(self._load_from_disk)
+
+    def __getstate__(self):
+        """Prepare state for pickling - exclude the unpicklable cache."""
+        state = self.__dict__.copy()
+        # Remove the cache wrapper which can't be pickled
+        del state["_cached_load"]
+        return state
+
+    def __setstate__(self, state):
+        """Restore state after unpickling - recreate the cache."""
+        self.__dict__.update(state)
+        # Recreate the cache after unpickling
+        self._init_cache()
 
     def save_model(
         self, model: Any, name: str, boot_idx: Optional[int] = None
@@ -29,11 +49,20 @@ def save_model(
 
         return str(filepath)
 
+    def _load_from_disk(self, filepath: str) -> Any:
+        """Internal method to load a model from disk (cached)."""
+        return joblib.load(filepath)
+
     def load_model(self, ref: Union[Any, str]) -> Any:
+        """Load a model, using cache for repeated loads of the same file."""
         if not self.enabled or not isinstance(ref, str):
             return ref
 
-        return joblib.load(ref)
+        return self._cached_load(ref)
+
+    def clear_cache(self) -> None:
+        """Clear the model loading cache. Call between bootstrap iterations if needed."""
+        self._cached_load.cache_clear()
 
     def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
         if not self.enabled:
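An instance-bound `lru_cache` wrapper is not picklable, which is why `__getstate__` drops it and `__setstate__` rebuilds it. A self-contained sketch of the same pattern, with a stub `_load` standing in for `joblib.load`:

```python
import pickle
from functools import lru_cache

class Loader:  # hypothetical stand-in for the offloader
    def __init__(self):
        self._init_cache()

    def _init_cache(self):
        # one cache per instance, wrapping the bound loader method
        self._cached_load = lru_cache(maxsize=32)(self._load)

    def _load(self, path):
        return f"model-from-{path}"  # stand-in for joblib.load(path)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_cached_load"]  # the wrapper itself cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._init_cache()  # fresh (empty) cache after unpickling

restored = pickle.loads(pickle.dumps(Loader()))  # round-trips cleanly
assert restored._cached_load("a.joblib") == "model-from-a.joblib"
```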
diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py
index ead8814..220500a 100644
--- a/pySEQTarget/weighting/_weight_fit.py
+++ b/pySEQTarget/weighting/_weight_fit.py
@@ -2,6 +2,32 @@
 import statsmodels.formula.api as smf
 
 
+def _get_subset_for_level(self, WDT, level_idx, level, tx_lag_col, exclude_followup_zero=False):
+    """
+    Helper to create the subset of data for a given treatment level.
+    Consolidates the repeated filtering logic from _fit_numerator and _fit_denominator.
+    """
+    DT_subset = WDT
+
+    # Filter by excused column if applicable
+    if self.excused and self.excused_colnames[level_idx] is not None:
+        DT_subset = DT_subset[DT_subset[self.excused_colnames[level_idx]] == 0]
+
+    # Filter by treatment lag condition
+    if self.weight_lag_condition:
+        DT_subset = DT_subset[DT_subset[tx_lag_col] == level]
+
+    # Exclude followup == 0 for denominator (not pre-expansion)
+    if exclude_followup_zero:
+        DT_subset = DT_subset[DT_subset["followup"] != 0]
+
+    # Filter by eligibility column if applicable
+    if self.weight_eligible_colnames[level_idx] is not None:
+        DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[level_idx]] == 1]
+
+    return DT_subset
+
+
 def _fit_pair(
     self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None
 ):
@@ -54,7 +80,7 @@ def _fit_numerator(self, WDT):
         formula = f"{predictor}~1"
     else:
         formula = f"{predictor}~{self.numerator}"
-    tx_bas = (
+    tx_lag_col = (
        f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag"
    )
    fits = []
@@ -62,14 +88,7 @@ def _fit_numerator(self, WDT):
     # treatment_level=[1,2] or dose-response always uses mnlogit
     is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring"
     for i, level in enumerate(self.treatment_level):
-        if self.excused and self.excused_colnames[i] is not None:
-            DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
-        else:
-            DT_subset = WDT
-        if self.weight_lag_condition:
-            DT_subset = DT_subset[DT_subset[tx_bas] == level]
-        if self.weight_eligible_colnames[i] is not None:
-            DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1]
+        DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col)
         # Use logit for binary 0/1 censoring, mnlogit otherwise
         if is_binary:
             model = smf.logit(formula, DT_subset)
@@ -99,18 +118,11 @@ def _fit_denominator(self, WDT):
     # Use logit for binary 0/1 treatment with censoring method only
     # treatment_level=[1,2] or dose-response always uses mnlogit
     is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring"
+    exclude_followup_zero = not self.weight_preexpansion
     for i, level in enumerate(self.treatment_level):
-        if self.excused and self.excused_colnames[i] is not None:
-            DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
-        else:
-            DT_subset = WDT
-        if self.weight_lag_condition:
-            DT_subset = DT_subset[DT_subset["tx_lag"] == level]
-        if not self.weight_preexpansion:
-            DT_subset = DT_subset[DT_subset["followup"] != 0]
-        if self.weight_eligible_colnames[i] is not None:
-            DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1]
-
+        DT_subset = _get_subset_for_level(
+            self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero
+        )
         # Use logit for binary 0/1 censoring, mnlogit otherwise
         if is_binary:
             model = smf.logit(formula, DT_subset)
diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py
index 354e6fc..551e3b4 100644
--- a/pySEQTarget/weighting/_weight_pred.py
+++ b/pySEQTarget/weighting/_weight_pred.py
@@ -4,6 +4,23 @@
 from ..helpers import _predict_model
 
 
+def _extract_class_probability(p, level_idx, is_binary):
+    """
+    Extract the probability for a specific class from model predictions.
+    Handles both binary (logit) and multinomial (mnlogit) output formats.
+    """
+    if is_binary:
+        # logit returns P(Y=1) directly as 1D array
+        # For level 0: want P(stay at 0) = 1 - P(Y=1)
+        # For level 1: want P(switch to 1) = P(Y=1)
+        return p if level_idx == 1 else (1 - p)
+    else:
+        # mnlogit returns [P(Y=0), P(Y=1), ...] as 2D array
+        if p.ndim == 1:
+            p = p.reshape(-1, 1)
+        return p[:, level_idx]
+
+
 def _weight_predict(self, WDT):
     grouping = [self.id_col]
     grouping += ["trial"] if not self.weight_preexpansion else []
@@ -27,78 +44,43 @@ def _weight_predict(self, WDT):
     )
 
     if not self.excused:
+        # Pre-allocate arrays once for all treatment levels
+        pred_num = np.ones(WDT.height)
+        pred_denom = np.ones(WDT.height)
+
+        # Pre-compute treatment switch mask once
+        switched_treatment = (WDT[self.treatment_col] != WDT["tx_lag"]).to_numpy()
+        tx_lag_array = WDT["tx_lag"].to_numpy()
+
         for i, level in enumerate(self.treatment_level):
-            mask = pl.col("tx_lag") == level
-            lag_mask = (WDT["tx_lag"] == level).to_numpy()
+            lag_mask = tx_lag_array == level
 
             # Load models via offloader (handles both offloaded and in-memory models)
             denom_model = self._offloader.load_model(self.denominator_model[i])
             num_model = self._offloader.load_model(self.numerator_model[i])
 
-            if denom_model is not None:
-                pred_denom = np.ones(WDT.height)
-                if lag_mask.sum() > 0:
-                    subset = WDT.filter(pl.Series(lag_mask))
-                    p = _predict_model(self, denom_model, subset)
-
-                    # Handle binary vs multinomial prediction output
-                    if is_binary:
-                        # logit returns P(Y=1) directly as 1D array
-                        # For i=0 (level 0): want P(stay at 0) = 1 - P(Y=1)
-                        # For i=1 (level 1): want P(switch to 1) = P(Y=1)
-                        p_class = p if i == 1 else (1 - p)
-                    else:
-                        # mnlogit returns [P(Y=0), P(Y=1), ...] as 2D array
-                        if p.ndim == 1:
-                            p = p.reshape(-1, 1)
-                        p_class = p[:, i]
-
-                    switched_treatment = (
-                        subset[self.treatment_col] != subset["tx_lag"]
-                    ).to_numpy()
-                    pred_denom[lag_mask] = np.where(
-                        switched_treatment, 1.0 - p_class, p_class
-                    )
-            else:
-                pred_denom = np.ones(WDT.height)
-
-            if num_model is not None:
-                pred_num = np.ones(WDT.height)
-                if lag_mask.sum() > 0:
-                    subset = WDT.filter(pl.Series(lag_mask))
-                    p = _predict_model(self, num_model, subset)
-
-                    # Handle binary vs multinomial prediction output
-                    if is_binary:
-                        # logit returns P(Y=1) directly as 1D array
-                        p_class = p if i == 1 else (1 - p)
-                    else:
-                        # mnlogit returns [P(Y=0), P(Y=1), ...] as 2D array
-                        if p.ndim == 1:
-                            p = p.reshape(-1, 1)
-                        p_class = p[:, i]
+            if denom_model is not None and lag_mask.sum() > 0:
+                subset = WDT.filter(pl.Series(lag_mask))
+                p = _predict_model(self, denom_model, subset)
+                p_class = _extract_class_probability(p, i, is_binary)
+                pred_denom[lag_mask] = np.where(
+                    switched_treatment[lag_mask], 1.0 - p_class, p_class
+                )
 
-                    switched_treatment = (
-                        subset[self.treatment_col] != subset["tx_lag"]
-                    ).to_numpy()
-                    pred_num[lag_mask] = np.where(
-                        switched_treatment, 1.0 - p_class, p_class
-                    )
-            else:
-                pred_num = np.ones(WDT.height)
+            if num_model is not None and lag_mask.sum() > 0:
+                subset = WDT.filter(pl.Series(lag_mask))
+                p = _predict_model(self, num_model, subset)
+                p_class = _extract_class_probability(p, i, is_binary)
+                pred_num[lag_mask] = np.where(
+                    switched_treatment[lag_mask], 1.0 - p_class, p_class
+                )
 
-            WDT = WDT.with_columns(
-                [
-                    pl.when(mask)
-                    .then(pl.Series(pred_num))
-                    .otherwise(pl.col("numerator"))
-                    .alias("numerator"),
-                    pl.when(mask)
-                    .then(pl.Series(pred_denom))
-                    .otherwise(pl.col("denominator"))
-                    .alias("denominator"),
-                ]
-            )
+        WDT = WDT.with_columns(
+            [
+                pl.Series("numerator", pred_num),
+                pl.Series("denominator", pred_denom),
+            ]
+        )
 
     else:
         for i, level in enumerate(self.treatment_level):
diff --git a/pyproject.toml b/pyproject.toml
index 8633196..e81af2a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "pySEQTarget"
-version = "0.11.0"
+version = "0.12.0"
 description = "Sequentially Nested Target Trial Emulation"
 readme = "README.md"
 license = {text = "MIT"}
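Two closing sketches for the weighting changes above. First, `_get_subset_for_level` applies its filters as successive boolean subsets, so the conjunction is order-independent and each flag simply switches one filter on or off; a toy pandas demo with hypothetical column names (`excused`, `tx_lag`, `eligible`):

```python
import pandas as pd

WDT = pd.DataFrame({  # hypothetical columns, toy values
    "excused":  [0, 0, 1, 0],
    "tx_lag":   [1, 0, 1, 1],
    "eligible": [1, 1, 1, 0],
})

DT_subset = WDT
DT_subset = DT_subset[DT_subset["excused"] == 0]   # drop excused rows
DT_subset = DT_subset[DT_subset["tx_lag"] == 1]    # keep current treatment level
DT_subset = DT_subset[DT_subset["eligible"] == 1]  # keep weight-eligible rows
print(DT_subset.index.tolist())  # -> [0]
```

Second, a numpy-only sketch of the masked assignment `_weight_pred` now performs once per treatment level, with made-up values: rows whose lagged treatment matches the level get `P(class)`, and within those, rows that switched treatment get `1 - P(class)`:

```python
import numpy as np

pred = np.ones(5)  # default contribution of 1 outside the level's rows
lag_mask = np.array([True, True, False, True, False])   # tx_lag == level
switched = np.array([False, True, False, False, True])  # treatment != tx_lag
p_class = np.array([0.8, 0.6, 0.9])  # made-up P(class) for the lag_mask rows

pred[lag_mask] = np.where(switched[lag_mask], 1.0 - p_class, p_class)
print(pred)  # [0.8, 0.4, 1.0, 0.9, 1.0]
```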