From 341ddfa8f2ef3ba595b59f3f6288b1562bee9387 Mon Sep 17 00:00:00 2001
From: Sean Koval
Date: Fri, 13 Feb 2026 16:02:41 -0500
Subject: [PATCH 1/2] feat(openquant): add AFML ch6 ensemble methods module

---
 crates/openquant/src/ensemble_methods.rs   | 294 +++++++++++++++++++++
 crates/openquant/src/lib.rs                |   1 +
 crates/openquant/tests/ensemble_methods.rs |  80 ++++++
 docs-site/src/data/afmlDocsState.ts        |  20 ++
 docs-site/src/data/moduleDocs.ts           |  50 ++++
 5 files changed, 445 insertions(+)
 create mode 100644 crates/openquant/src/ensemble_methods.rs
 create mode 100644 crates/openquant/tests/ensemble_methods.rs

diff --git a/crates/openquant/src/ensemble_methods.rs b/crates/openquant/src/ensemble_methods.rs
new file mode 100644
index 0000000..9deb070
--- /dev/null
+++ b/crates/openquant/src/ensemble_methods.rs
@@ -0,0 +1,294 @@
+//! Ensemble-method utilities aligned to AFML Chapter 6.
+//!
+//! This module provides:
+//! - Bias/variance/noise diagnostics for ensemble forecasts.
+//! - Bagging mechanics (bootstrap + sequential-bootstrap wrappers).
+//! - Aggregation helpers (majority vote and mean probability).
+//! - Dependency-aware diagnostics for when bagging is likely to underperform.
+//! - A practical bagging-vs-boosting recommendation heuristic.
+
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+use crate::sampling::seq_bootstrap;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum EnsembleMethod {
+    Bagging,
+    Boosting,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct BiasVarianceNoise {
+    pub bias_sq: f64,
+    pub variance: f64,
+    pub noise: f64,
+    pub mse: f64,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct BaggingBoostingDecision {
+    pub recommended: EnsembleMethod,
+    pub expected_bagging_variance: f64,
+    pub expected_variance_reduction: f64,
+}
+
+pub fn bias_variance_noise(
+    y_true: &[f64],
+    per_model_predictions: &[Vec<f64>],
+) -> Result<BiasVarianceNoise, String> {
+    if y_true.is_empty() {
+        return Err("y_true cannot be empty".to_string());
+    }
+    if per_model_predictions.is_empty() {
+        return Err("per_model_predictions cannot be empty".to_string());
+    }
+    if per_model_predictions.iter().any(|row| row.len() != y_true.len()) {
+        return Err("prediction length mismatch".to_string());
+    }
+
+    let n_models = per_model_predictions.len() as f64;
+    let n_samples = y_true.len() as f64;
+
+    let mut bias_sq_sum = 0.0;
+    let mut var_sum = 0.0;
+    let mut mse_sum = 0.0;
+
+    for i in 0..y_true.len() {
+        let mut mean_pred = 0.0;
+        for model in per_model_predictions {
+            mean_pred += model[i];
+            let err = model[i] - y_true[i];
+            mse_sum += err * err;
+        }
+        mean_pred /= n_models;
+
+        let bias = mean_pred - y_true[i];
+        bias_sq_sum += bias * bias;
+
+        let mut local_var = 0.0;
+        for model in per_model_predictions {
+            let d = model[i] - mean_pred;
+            local_var += d * d;
+        }
+        local_var /= n_models;
+        var_sum += local_var;
+    }
+
+    let bias_sq = bias_sq_sum / n_samples;
+    let variance = var_sum / n_samples;
+    let mse = mse_sum / (n_samples * n_models);
+    let noise = (mse - bias_sq - variance).max(0.0);
+
+    Ok(BiasVarianceNoise {
+        bias_sq,
+        variance,
+        noise,
+        mse,
+    })
+}
+
+pub fn bootstrap_sample_indices(
+    n_samples: usize,
+    sample_size: usize,
+    seed: u64,
+) -> Result<Vec<usize>, String> {
+    if n_samples == 0 || sample_size == 0 {
+        return Err("n_samples and sample_size must be > 0".to_string());
+    }
+    let mut rng = StdRng::seed_from_u64(seed);
+    Ok((0..sample_size).map(|_| rng.gen_range(0..n_samples)).collect())
+}
+
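+// Usage sketch for the two samplers (illustrative values, not part of the API;
+// assumes a caller returning Result<_, String> so `?` is available):
+//
+//     let plain = bootstrap_sample_indices(100, 100, 42)?; // IID draw with replacement
+//     assert_eq!(plain.len(), 100);
+//
+// A plain bootstrap of size n leaves roughly 1/e (~37%) of the rows out of any
+// single draw. The sequential variant below instead favors observations whose
+// label windows overlap least with what has already been drawn (AFML ch. 4).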
+pub fn sequential_bootstrap_sample_indices(
+    ind_mat: &[Vec<u8>],
+    sample_size: usize,
+    seed: u64,
+) -> Result<Vec<usize>, String> {
+    if sample_size == 0 {
+        return Err("sample_size must be > 0".to_string());
+    }
+    if ind_mat.is_empty() {
+        return Err("ind_mat cannot be empty".to_string());
+    }
+    let n_labels = ind_mat.first().map(|r| r.len()).unwrap_or(0);
+    if n_labels == 0 {
+        return Err("ind_mat must include at least one label column".to_string());
+    }
+
+    let mut rng = StdRng::seed_from_u64(seed);
+    let warmup: Vec<usize> = (0..sample_size).map(|_| rng.gen_range(0..n_labels)).collect();
+    Ok(seq_bootstrap(ind_mat, Some(sample_size), Some(warmup)))
+}
+
+pub fn aggregate_regression_mean(per_model_predictions: &[Vec<f64>]) -> Result<Vec<f64>, String> {
+    if per_model_predictions.is_empty() {
+        return Err("per_model_predictions cannot be empty".to_string());
+    }
+    let n = per_model_predictions[0].len();
+    if n == 0 {
+        return Err("prediction rows cannot be empty".to_string());
+    }
+    if per_model_predictions.iter().any(|row| row.len() != n) {
+        return Err("prediction length mismatch".to_string());
+    }
+
+    let mut out = vec![0.0; n];
+    for row in per_model_predictions {
+        for (i, v) in row.iter().enumerate() {
+            out[i] += *v;
+        }
+    }
+    let denom = per_model_predictions.len() as f64;
+    for v in &mut out {
+        *v /= denom;
+    }
+    Ok(out)
+}
+
+pub fn aggregate_classification_vote(per_model_predictions: &[Vec<u8>]) -> Result<Vec<u8>, String> {
+    if per_model_predictions.is_empty() {
+        return Err("per_model_predictions cannot be empty".to_string());
+    }
+    let n = per_model_predictions[0].len();
+    if n == 0 {
+        return Err("prediction rows cannot be empty".to_string());
+    }
+    if per_model_predictions.iter().any(|row| row.len() != n) {
+        return Err("prediction length mismatch".to_string());
+    }
+    if per_model_predictions
+        .iter()
+        .flat_map(|row| row.iter())
+        .any(|label| *label > 1)
+    {
+        return Err("classification vote expects binary labels in {0,1}".to_string());
+    }
+
+    let mut out = vec![0u8; n];
+    for i in 0..n {
+        let votes = per_model_predictions.iter().map(|row| row[i] as usize).sum::<usize>();
+        out[i] = if votes * 2 >= per_model_predictions.len() {
+            1
+        } else {
+            0
+        };
+    }
+    Ok(out)
+}
+
+pub fn aggregate_classification_probability_mean(
+    per_model_probabilities: &[Vec<f64>],
+    threshold: f64,
+) -> Result<(Vec<f64>, Vec<u8>), String> {
+    if !(0.0..=1.0).contains(&threshold) {
+        return Err("threshold must be in [0,1]".to_string());
+    }
+    let probs = aggregate_regression_mean(per_model_probabilities)?;
+    if probs.iter().any(|p| !(0.0..=1.0).contains(p)) {
+        return Err("probabilities must be in [0,1]".to_string());
+    }
+    let labels = probs.iter().map(|p| if *p >= threshold { 1 } else { 0 }).collect();
+    Ok((probs, labels))
+}
+
+pub fn average_pairwise_prediction_correlation(per_model_predictions: &[Vec<f64>]) -> Result<f64, String> {
+    if per_model_predictions.len() < 2 {
+        return Err("at least two model prediction rows are required".to_string());
+    }
+    let n = per_model_predictions[0].len();
+    if n < 2 {
+        return Err("prediction rows must have at least two samples".to_string());
+    }
+    if per_model_predictions.iter().any(|row| row.len() != n) {
+        return Err("prediction length mismatch".to_string());
+    }
+
+    let mut corr_sum = 0.0;
+    let mut pairs = 0usize;
+    for i in 0..per_model_predictions.len() {
+        for j in (i + 1)..per_model_predictions.len() {
+            corr_sum += pearson_corr(&per_model_predictions[i], &per_model_predictions[j]);
+            pairs += 1;
+        }
+    }
+    Ok(corr_sum / pairs as f64)
+}
+
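+// Worked example of the formula implemented below (a sketch, not executed
+// here): with single-estimator variance 1.0, average pairwise correlation
+// rho = 0.3, and N = 20 estimators,
+//
+//     sigma_bag^2 = 1.0 * (0.3 + (1.0 - 0.3) / 20.0) = 0.335
+//
+// and the N -> infinity limit is rho * sigma^2 = 0.3, so correlation, not the
+// estimator count, caps how much variance bagging can remove.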
+pub fn bagging_ensemble_variance(
+    single_estimator_variance: f64,
+    average_correlation: f64,
+    n_estimators: usize,
+) -> Result<f64, String> {
+    if single_estimator_variance < 0.0 {
+        return Err("single_estimator_variance must be non-negative".to_string());
+    }
+    if !(-1.0..=1.0).contains(&average_correlation) {
+        return Err("average_correlation must be in [-1,1]".to_string());
+    }
+    if n_estimators == 0 {
+        return Err("n_estimators must be > 0".to_string());
+    }
+
+    let n = n_estimators as f64;
+    let rho = average_correlation;
+    Ok(single_estimator_variance * (rho + (1.0 - rho) / n))
+}
+
+pub fn recommend_bagging_vs_boosting(
+    base_estimator_accuracy: f64,
+    average_prediction_correlation: f64,
+    label_redundancy: f64,
+    single_estimator_variance: f64,
+    n_estimators: usize,
+) -> Result<BaggingBoostingDecision, String> {
+    if !(0.0..=1.0).contains(&base_estimator_accuracy) {
+        return Err("base_estimator_accuracy must be in [0,1]".to_string());
+    }
+    if !(0.0..=1.0).contains(&label_redundancy) {
+        return Err("label_redundancy must be in [0,1]".to_string());
+    }
+    let bag_var =
+        bagging_ensemble_variance(single_estimator_variance, average_prediction_correlation, n_estimators)?;
+    let expected_reduction = (single_estimator_variance - bag_var).max(0.0);
+
+    // Heuristic criteria:
+    // - weak learners (accuracy near random) favor boosting for bias reduction.
+    // - highly correlated learners or high label redundancy reduce bagging gains.
+    let weak_learner = base_estimator_accuracy < 0.55;
+    let highly_correlated = average_prediction_correlation >= 0.75;
+    let redundant_labels = label_redundancy >= 0.70;
+
+    let recommended = if weak_learner || highly_correlated || redundant_labels {
+        EnsembleMethod::Boosting
+    } else {
+        EnsembleMethod::Bagging
+    };
+
+    Ok(BaggingBoostingDecision {
+        recommended,
+        expected_bagging_variance: bag_var,
+        expected_variance_reduction: expected_reduction,
+    })
+}
+
+fn pearson_corr(x: &[f64], y: &[f64]) -> f64 {
+    let mx = x.iter().sum::<f64>() / x.len() as f64;
+    let my = y.iter().sum::<f64>() / y.len() as f64;
+
+    let mut num = 0.0;
+    let mut den_x = 0.0;
+    let mut den_y = 0.0;
+    for (a, b) in x.iter().zip(y.iter()) {
+        let dx = *a - mx;
+        let dy = *b - my;
+        num += dx * dy;
+        den_x += dx * dx;
+        den_y += dy * dy;
+    }
+    if den_x == 0.0 || den_y == 0.0 {
+        0.0
+    } else {
+        num / (den_x.sqrt() * den_y.sqrt())
+    }
+}
diff --git a/crates/openquant/src/lib.rs b/crates/openquant/src/lib.rs
index 223b1e9..c068ae7 100644
--- a/crates/openquant/src/lib.rs
+++ b/crates/openquant/src/lib.rs
@@ -4,6 +4,7 @@ pub mod cla;
 pub mod codependence;
 pub mod cross_validation;
 pub mod data_structures;
+pub mod ensemble_methods;
 pub mod ef3m;
 pub mod etf_trick;
 pub mod feature_importance;
diff --git a/crates/openquant/tests/ensemble_methods.rs b/crates/openquant/tests/ensemble_methods.rs
new file mode 100644
index 0000000..f571357
--- /dev/null
+++ b/crates/openquant/tests/ensemble_methods.rs
@@ -0,0 +1,80 @@
+use openquant::ensemble_methods::{
+    aggregate_classification_probability_mean, aggregate_classification_vote, aggregate_regression_mean,
+    average_pairwise_prediction_correlation, bagging_ensemble_variance, bias_variance_noise,
+    bootstrap_sample_indices, recommend_bagging_vs_boosting, sequential_bootstrap_sample_indices,
+    EnsembleMethod,
+};
+
+#[test]
+fn test_bias_variance_noise_decomposition() {
+    let y = vec![1.0, 0.0, 1.0, 0.0];
+    let preds = vec![
+        vec![0.9, 0.1, 0.8, 0.2],
+        vec![0.8, 0.2, 0.7, 0.3],
+        vec![1.0, 0.0, 0.9, 0.1],
+    ];
+
+    let out = bias_variance_noise(&y, &preds).unwrap();
+    assert!(out.bias_sq >= 0.0);
+    assert!(out.variance >= 0.0);
+    assert!(out.noise >= 0.0);
+    assert!(out.mse >= 0.0);
+
+    let lhs = out.bias_sq + out.variance + out.noise;
+    assert!((lhs - out.mse).abs() < 1e-10);
+}
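+
+// Note on the identity checked above: with a fixed target vector there is no
+// independent noise estimate, so `bias_variance_noise` assigns noise the
+// clamped residual mse - bias^2 - variance; the three components therefore
+// re-sum to the MSE by construction, up to floating-point error.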
+
+#[test]
+fn test_bootstrap_and_sequential_bootstrap_shapes() {
+    let b = bootstrap_sample_indices(10, 6, 7).unwrap();
+    assert_eq!(b.len(), 6);
+    assert!(b.iter().all(|v| *v < 10));
+
+    let ind_mat = vec![vec![1, 0, 1, 0], vec![0, 1, 0, 1], vec![1, 1, 0, 0]];
+    let sb = sequential_bootstrap_sample_indices(&ind_mat, 8, 11).unwrap();
+    assert_eq!(sb.len(), 8);
+    assert!(sb.iter().all(|v| *v < ind_mat[0].len()));
+}
+
+#[test]
+fn test_aggregation_helpers() {
+    let reg = aggregate_regression_mean(&[vec![1.0, 3.0], vec![3.0, 1.0]]).unwrap();
+    assert_eq!(reg, vec![2.0, 2.0]);
+
+    let vote = aggregate_classification_vote(&[vec![1, 0, 1], vec![1, 1, 0], vec![0, 1, 1]]).unwrap();
+    assert_eq!(vote, vec![1, 1, 1]);
+
+    let (prob, labels) =
+        aggregate_classification_probability_mean(&[vec![0.9, 0.2], vec![0.7, 0.4], vec![0.8, 0.3]], 0.5).unwrap();
+    assert!((prob[0] - 0.8).abs() < 1e-12);
+    assert!((prob[1] - 0.3).abs() < 1e-12);
+    assert_eq!(labels, vec![1, 0]);
+}
+
+#[test]
+fn test_variance_reduction_and_redundancy_failure_mode() {
+    let low_corr = bagging_ensemble_variance(1.0, 0.0, 10).unwrap();
+    assert!((low_corr - 0.1).abs() < 1e-12);
+
+    let high_corr = bagging_ensemble_variance(1.0, 0.95, 10).unwrap();
+    assert!(high_corr > 0.9);
+    assert!(high_corr > low_corr);
+}
+
+#[test]
+fn test_pairwise_correlation_and_strategy_recommendation() {
+    let weak_preds = vec![
+        vec![0.50, 0.52, 0.48, 0.50],
+        vec![0.51, 0.53, 0.49, 0.51],
+        vec![0.49, 0.51, 0.47, 0.49],
+    ];
+    let corr = average_pairwise_prediction_correlation(&weak_preds).unwrap();
+    assert!(corr > 0.95);
+
+    let weak = recommend_bagging_vs_boosting(0.53, corr, 0.8, 1.0, 16).unwrap();
+    assert_eq!(weak.recommended, EnsembleMethod::Boosting);
+
+    let strong_diverse = recommend_bagging_vs_boosting(0.68, 0.15, 0.25, 1.0, 16).unwrap();
+    assert_eq!(strong_diverse.recommended, EnsembleMethod::Bagging);
+    assert!(strong_diverse.expected_variance_reduction > 0.0);
+}
diff --git a/docs-site/src/data/afmlDocsState.ts b/docs-site/src/data/afmlDocsState.ts
index eae56db..ba01d1a 100644
--- a/docs-site/src/data/afmlDocsState.ts
+++ b/docs-site/src/data/afmlDocsState.ts
@@ -88,6 +88,26 @@ export const afmlDocsState = {
       }
     ]
   },
+  {
+    "chapter": "CHAPTER 6",
+    "theme": "Ensemble methods",
+    "status": "done",
+    "chunkCount": 7,
+    "sections": [
+      {
+        "id": "chapter-6-ensemble_methods",
+        "module": "ensemble_methods",
+        "slug": "ensemble-methods",
+        "status": "done"
+      },
+      {
+        "id": "chapter-6-sb_bagging",
+        "module": "sb_bagging",
+        "slug": "sb-bagging",
+        "status": "done"
+      }
+    ]
+  },
   {
     "chapter": "CHAPTER 7",
     "theme": "Leakage-aware validation",
diff --git a/docs-site/src/data/moduleDocs.ts b/docs-site/src/data/moduleDocs.ts
index 82e3e19..bf75fac 100644
--- a/docs-site/src/data/moduleDocs.ts
+++ b/docs-site/src/data/moduleDocs.ts
@@ -201,6 +201,56 @@ export const moduleDocs: ModuleDoc[] = [
     ],
     notes: ["Use as initialization for more expensive optimizers.", "Sensitive to higher-moment estimation noise."],
   },
+  {
+    slug: "ensemble-methods",
+    module: "ensemble_methods",
+    subject: "Sampling, Validation and ML Diagnostics",
+    summary: "Bias/variance diagnostics and practical bagging-vs-boosting ensemble utilities.",
+    whyItExists:
+      "AFML Chapter 6 emphasizes that ensemble gains depend on error decomposition and forecast dependence, not just estimator count.",
+    keyApis: [
+      "bias_variance_noise",
+      "bootstrap_sample_indices",
+      "sequential_bootstrap_sample_indices",
+      "aggregate_classification_vote",
+      "aggregate_classification_probability_mean",
+      "average_pairwise_prediction_correlation",
+      "bagging_ensemble_variance",
+      "recommend_bagging_vs_boosting",
+    ],
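+    // The bagging-variance formula below has limit rho * sigma^2 as N grows,
+    // which is why the average pairwise correlation is surfaced as a
+    // first-class diagnostic alongside the estimator count.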
"aggregate_classification_probability_mean", + "average_pairwise_prediction_correlation", + "bagging_ensemble_variance", + "recommend_bagging_vs_boosting", + ], + formulas: [ + { + label: "Error Decomposition", + latex: "\\operatorname{MSE}=\\operatorname{Bias}^2+\\operatorname{Var}+\\operatorname{Noise}", + }, + { + label: "Bagging Variance Under Average Correlation", + latex: "\\sigma^2_{bag}=\\sigma^2\\left(\\rho+\\frac{1-\\rho}{N}\\right)", + }, + { + label: "Majority Vote and Mean Probability", + latex: + "\\hat y=\\mathbf 1\\left(\\frac{1}{N}\\sum_{m=1}^N \\hat p_m \\ge \\tau\\right),\\quad \\hat p=\\frac{1}{N}\\sum_{m=1}^N \\hat p_m", + }, + ], + examples: [ + { + title: "Assess Ensemble Variance and Recommendation", + language: "rust", + code: `use openquant::ensemble_methods::{\n average_pairwise_prediction_correlation,\n bagging_ensemble_variance,\n recommend_bagging_vs_boosting,\n};\n\nlet preds = vec![\n vec![0.51, 0.49, 0.52, 0.50],\n vec![0.50, 0.48, 0.53, 0.49],\n vec![0.52, 0.50, 0.51, 0.50],\n];\n\nlet rho = average_pairwise_prediction_correlation(&preds)?;\nlet bag_var = bagging_ensemble_variance(1.0, rho, 20)?;\nlet decision = recommend_bagging_vs_boosting(0.54, rho, 0.75, 1.0, 20)?;\n\nprintln!(\"rho={rho:.3}, var={bag_var:.3}, rec={:?}\", decision.recommended);`, + }, + { + title: "Aggregate Bagged Classifier Outputs", + language: "rust", + code: `use openquant::ensemble_methods::{\n aggregate_classification_vote,\n aggregate_classification_probability_mean,\n};\n\nlet vote = aggregate_classification_vote(&[\n vec![1, 0, 1],\n vec![1, 1, 0],\n vec![0, 1, 1],\n])?;\n\nlet (mean_prob, labels) = aggregate_classification_probability_mean(&[\n vec![0.9, 0.2, 0.6],\n vec![0.8, 0.3, 0.5],\n vec![0.7, 0.4, 0.4],\n], 0.5)?;\n\nassert_eq!(vote, vec![1, 1, 1]);\nassert_eq!(labels, vec![1, 0, 1]);\nassert_eq!(mean_prob.len(), 3);`, + }, + ], + notes: [ + "If base learners are highly correlated, bagging variance reduction is minimal even with many estimators.", + "Sequential-bootstrap-style sampling is preferable under heavy label overlap and non-IID observations.", + "Boosting is usually preferable for weak learners (bias reduction); bagging is usually preferable for unstable learners (variance reduction).", + ], + }, { slug: "etf-trick", module: "etf_trick", From 95e44c9d4b40c15e33f4cef230fd37ccf165f7f2 Mon Sep 17 00:00:00 2001 From: Sean Koval Date: Fri, 13 Feb 2026 16:24:20 -0500 Subject: [PATCH 2/2] style(openquant): format ensemble methods files for CI --- crates/openquant/src/ensemble_methods.rs | 30 ++++++++-------------- crates/openquant/src/lib.rs | 2 +- crates/openquant/tests/ensemble_methods.rs | 24 ++++++++--------- 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/crates/openquant/src/ensemble_methods.rs b/crates/openquant/src/ensemble_methods.rs index 9deb070..a35b8dc 100644 --- a/crates/openquant/src/ensemble_methods.rs +++ b/crates/openquant/src/ensemble_methods.rs @@ -80,12 +80,7 @@ pub fn bias_variance_noise( let mse = mse_sum / (n_samples * n_models); let noise = (mse - bias_sq - variance).max(0.0); - Ok(BiasVarianceNoise { - bias_sq, - variance, - noise, - mse, - }) + Ok(BiasVarianceNoise { bias_sq, variance, noise, mse }) } pub fn bootstrap_sample_indices( @@ -157,22 +152,14 @@ pub fn aggregate_classification_vote(per_model_predictions: &[Vec]) -> Resul if per_model_predictions.iter().any(|row| row.len() != n) { return Err("prediction length mismatch".to_string()); } - if per_model_predictions - .iter() - .flat_map(|row| row.iter()) - 
+    notes: [
+      "If base learners are highly correlated, bagging variance reduction is minimal even with many estimators.",
+      "Sequential-bootstrap-style sampling is preferable under heavy label overlap and non-IID observations.",
+      "Boosting is usually preferable for weak learners (bias reduction); bagging is usually preferable for unstable learners (variance reduction).",
+    ],
+  },
   {
     slug: "etf-trick",
     module: "etf_trick",

From 95e44c9d4b40c15e33f4cef230fd37ccf165f7f2 Mon Sep 17 00:00:00 2001
From: Sean Koval
Date: Fri, 13 Feb 2026 16:24:20 -0500
Subject: [PATCH 2/2] style(openquant): format ensemble methods files for CI

---
 crates/openquant/src/ensemble_methods.rs   | 30 ++++++++--------------
 crates/openquant/src/lib.rs                |  2 +-
 crates/openquant/tests/ensemble_methods.rs | 24 ++++++++---------
 3 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/crates/openquant/src/ensemble_methods.rs b/crates/openquant/src/ensemble_methods.rs
index 9deb070..a35b8dc 100644
--- a/crates/openquant/src/ensemble_methods.rs
+++ b/crates/openquant/src/ensemble_methods.rs
@@ -80,12 +80,7 @@
     let mse = mse_sum / (n_samples * n_models);
     let noise = (mse - bias_sq - variance).max(0.0);
 
-    Ok(BiasVarianceNoise {
-        bias_sq,
-        variance,
-        noise,
-        mse,
-    })
+    Ok(BiasVarianceNoise { bias_sq, variance, noise, mse })
 }
 
 pub fn bootstrap_sample_indices(
@@ -157,22 +152,14 @@ pub fn aggregate_classification_vote(per_model_predictions: &[Vec<u8>]) -> Result<Vec<u8>, String> {
     if per_model_predictions.iter().any(|row| row.len() != n) {
         return Err("prediction length mismatch".to_string());
     }
-    if per_model_predictions
-        .iter()
-        .flat_map(|row| row.iter())
-        .any(|label| *label > 1)
-    {
+    if per_model_predictions.iter().flat_map(|row| row.iter()).any(|label| *label > 1) {
         return Err("classification vote expects binary labels in {0,1}".to_string());
     }
 
     let mut out = vec![0u8; n];
     for i in 0..n {
         let votes = per_model_predictions.iter().map(|row| row[i] as usize).sum::<usize>();
-        out[i] = if votes * 2 >= per_model_predictions.len() {
-            1
-        } else {
-            0
-        };
+        out[i] = if votes * 2 >= per_model_predictions.len() { 1 } else { 0 };
     }
     Ok(out)
 }
@@ -192,7 +179,9 @@ pub fn aggregate_classification_probability_mean(
     Ok((probs, labels))
 }
 
-pub fn average_pairwise_prediction_correlation(per_model_predictions: &[Vec<f64>]) -> Result<f64, String> {
+pub fn average_pairwise_prediction_correlation(
+    per_model_predictions: &[Vec<f64>],
+) -> Result<f64, String> {
     if per_model_predictions.len() < 2 {
         return Err("at least two model prediction rows are required".to_string());
     }
@@ -248,8 +237,11 @@
     if !(0.0..=1.0).contains(&label_redundancy) {
         return Err("label_redundancy must be in [0,1]".to_string());
     }
-    let bag_var =
-        bagging_ensemble_variance(single_estimator_variance, average_prediction_correlation, n_estimators)?;
+    let bag_var = bagging_ensemble_variance(
+        single_estimator_variance,
+        average_prediction_correlation,
+        n_estimators,
+    )?;
     let expected_reduction = (single_estimator_variance - bag_var).max(0.0);
 
     // Heuristic criteria:
diff --git a/crates/openquant/src/lib.rs b/crates/openquant/src/lib.rs
index c068ae7..8964c04 100644
--- a/crates/openquant/src/lib.rs
+++ b/crates/openquant/src/lib.rs
@@ -4,8 +4,8 @@
 pub mod cla;
 pub mod codependence;
 pub mod cross_validation;
 pub mod data_structures;
-pub mod ensemble_methods;
 pub mod ef3m;
+pub mod ensemble_methods;
 pub mod etf_trick;
 pub mod feature_importance;
diff --git a/crates/openquant/tests/ensemble_methods.rs b/crates/openquant/tests/ensemble_methods.rs
index f571357..9dceade 100644
--- a/crates/openquant/tests/ensemble_methods.rs
+++ b/crates/openquant/tests/ensemble_methods.rs
@@ -1,18 +1,14 @@
 use openquant::ensemble_methods::{
-    aggregate_classification_probability_mean, aggregate_classification_vote, aggregate_regression_mean,
-    average_pairwise_prediction_correlation, bagging_ensemble_variance, bias_variance_noise,
-    bootstrap_sample_indices, recommend_bagging_vs_boosting, sequential_bootstrap_sample_indices,
-    EnsembleMethod,
+    aggregate_classification_probability_mean, aggregate_classification_vote,
+    aggregate_regression_mean, average_pairwise_prediction_correlation, bagging_ensemble_variance,
+    bias_variance_noise, bootstrap_sample_indices, recommend_bagging_vs_boosting,
+    sequential_bootstrap_sample_indices, EnsembleMethod,
 };
 
 #[test]
 fn test_bias_variance_noise_decomposition() {
     let y = vec![1.0, 0.0, 1.0, 0.0];
-    let preds = vec![
-        vec![0.9, 0.1, 0.8, 0.2],
-        vec![0.8, 0.2, 0.7, 0.3],
-        vec![1.0, 0.0, 0.9, 0.1],
-    ];
+    let preds = vec![vec![0.9, 0.1, 0.8, 0.2], vec![0.8, 0.2, 0.7, 0.3], vec![1.0, 0.0, 0.9, 0.1]];
 
     let out = bias_variance_noise(&y, &preds).unwrap();
     assert!(out.bias_sq >= 0.0);
@@ -41,11 +37,15 @@ fn test_aggregation_helpers() {
     let reg = aggregate_regression_mean(&[vec![1.0, 3.0], vec![3.0, 1.0]]).unwrap();
     assert_eq!(reg, vec![2.0, 2.0]);
 
-    let vote = aggregate_classification_vote(&[vec![1, 0, 1], vec![1, 1, 0], vec![0, 1, 1]]).unwrap();
+    let vote =
+        aggregate_classification_vote(&[vec![1, 0, 1], vec![1, 1, 0], vec![0, 1, 1]]).unwrap();
     assert_eq!(vote, vec![1, 1, 1]);
 
-    let (prob, labels) =
-        aggregate_classification_probability_mean(&[vec![0.9, 0.2], vec![0.7, 0.4], vec![0.8, 0.3]], 0.5).unwrap();
+    let (prob, labels) = aggregate_classification_probability_mean(
+        &[vec![0.9, 0.2], vec![0.7, 0.4], vec![0.8, 0.3]],
+        0.5,
+    )
+    .unwrap();
     assert!((prob[0] - 0.8).abs() < 1e-12);
     assert!((prob[1] - 0.3).abs() < 1e-12);
     assert_eq!(labels, vec![1, 0]);