From 9fdfd930f6ca90c449e79da4e3696846bf14cab7 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Sun, 8 Mar 2026 21:21:17 +0300 Subject: [PATCH 1/7] added Jaccard distance --- src/metrics/distance/jaccard.rs | 101 ++++++++++++++++++++++++++++++++ src/metrics/distance/mod.rs | 7 +++ 2 files changed, 108 insertions(+) create mode 100644 src/metrics/distance/jaccard.rs diff --git a/src/metrics/distance/jaccard.rs b/src/metrics/distance/jaccard.rs new file mode 100644 index 00000000..589f52cf --- /dev/null +++ b/src/metrics/distance/jaccard.rs @@ -0,0 +1,101 @@ +//! # Jaccard Distance +//! +//! Jaccard Distance measures dissimilarity between two integer-valued vectors of the same length. +//! Given two vectors \\( x \in ℝ^n \\), \\( y \in ℝ^n \\) the Jaccard distance between \\( x \\) and \\( y \\) is defined as +//! +//! \\[ d(x, y) = 1 - \frac{|x \cap y|}{|x \cup y|} \\] +//! +//! where \\(|x \cap y|\\) is the number of positions where both vectors are non-zero, +//! and \\(|x \cup y|\\) is the number of positions where at least one of the vectors is non-zero. +//! +//! Example: +//! +//! ``` +//! use smartcore::metrics::distance::Distance; +//! use smartcore::metrics::distance::jaccard::Jaccard; +//! +//! let a = vec![1, 0, 1, 1]; +//! let b = vec![1, 1, 0, 1]; +//! +//! let j: f64 = Jaccard::new().distance(&a, &b); +//! +//! ``` +//! +//! +//! + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +use super::Distance; +use crate::linalg::basic::arrays::ArrayView1; +use crate::numbers::basenum::Number; + +/// Jaccard distance between two integer-valued vectors +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct Jaccard { + _t: PhantomData, +} + +impl Jaccard { + /// instatiate the initial structure + pub fn new() -> Jaccard { + Jaccard { _t: PhantomData } + } +} + +impl Default for Jaccard { + fn default() -> Self { + Self::new() + } +} + +impl> Distance for Jaccard { + fn distance(&self, x: &A, y: &A) -> f64 { + if x.shape() != y.shape() { + panic!("Input vector sizes are different"); + } + + let (intersection, union): (usize, usize) = x + .iterator(0) + .zip(y.iterator(0)) + .map(|(a, b)| { + let a_nz = *a != T::zero(); + let b_nz = *b != T::zero(); + + match (a_nz, b_nz) { + (true, true) => (1, 1), + (true, false) | (false, true) => (0, 1), + (false, false) => (0, 0), + } + }) + .fold((0, 0), |acc, v| (acc.0 + v.0, acc.1 + v.1)); + + if union == 0 { + 0.0 + } else { + 1.0 - intersection as f64 / union as f64 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn jaccard_distance() { + let a = vec![1, 0, 1, 1]; + let b = vec![1, 1, 0, 1]; + + let j: f64 = Jaccard::new().distance(&a, &b); + + assert!((j - 0.5).abs() < 1e-8); + } +} diff --git a/src/metrics/distance/mod.rs b/src/metrics/distance/mod.rs index 6fdbaa46..f720013e 100644 --- a/src/metrics/distance/mod.rs +++ b/src/metrics/distance/mod.rs @@ -19,6 +19,8 @@ pub mod cosine; pub mod euclidian; /// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different. pub mod hamming; +/// Jaccard distance between two integer-valued vectors. +pub mod jaccard; /// The Mahalanobis distance is the distance between two points in multivariate space. pub mod mahalanobis; /// Also known as rectilinear distance, city block distance, taxicab metric. @@ -67,6 +69,11 @@ impl Distances { hamming::Hamming::new() } + /// Jaccard distance, see [`Jaccard`](jaccard/index.html) + pub fn jaccard() -> jaccard::Jaccard { + jaccard::Jaccard::new() + } + /// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html) pub fn mahalanobis, C: Array2 + LUDecomposable>( data: &M, From ece4f28446de178c3d1e9528d0429346259bf179 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Tue, 10 Mar 2026 11:48:09 +0300 Subject: [PATCH 2/7] two encounters of a bad pattern is_none() + unwrap(). FIXED. --- src/ensemble/base_forest_regressor.rs | 34 ++++++++++++-------- src/ensemble/random_forest_classifier.rs | 41 ++++++++++++++---------- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/ensemble/base_forest_regressor.rs b/src/ensemble/base_forest_regressor.rs index dc504446..4209034c 100644 --- a/src/ensemble/base_forest_regressor.rs +++ b/src/ensemble/base_forest_regressor.rs @@ -161,25 +161,31 @@ impl, Y: Array1 /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. pub fn predict_oob(&self, x: &X) -> Result { let (n, _) = x.shape(); - if self.samples.is_none() { - Err(Failed::because( - FailedError::PredictFailed, - "Need samples=true for OOB predictions.", - )) - } else if self.samples.as_ref().unwrap()[0].len() != n { - Err(Failed::because( + + let samples = match &self.samples { + Some(s) => s, + None => { + return Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )) + } + }; + + if samples[0].len() != n { + return Err(Failed::because( FailedError::PredictFailed, "Prediction matrix must match matrix used in training for OOB predictions.", - )) - } else { - let mut result = Y::zeros(n); + )); + } - for i in 0..n { - result.set(i, self.predict_for_row_oob(x, i)); - } + let mut result = Y::zeros(n); - Ok(result) + for i in 0..n { + result.set(i, self.predict_for_row_oob(x, i)); } + + Ok(result) } fn predict_for_row_oob(&self, x: &X, row: usize) -> TY { diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dabb2480..f4e8db3c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -539,27 +539,34 @@ impl, Y: Array1 Result { let (n, _) = x.shape(); - if self.samples.is_none() { - Err(Failed::because( - FailedError::PredictFailed, - "Need samples=true for OOB predictions.", - )) - } else if self.samples.as_ref().unwrap()[0].len() != n { - Err(Failed::because( + + let samples = match &self.samples { + Some(s) => s, + None => { + return Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )); + } + }; + + if samples[0].len() != n { + return Err(Failed::because( FailedError::PredictFailed, "Prediction matrix must match matrix used in training for OOB predictions.", - )) - } else { - let mut result = Y::zeros(n); + )); + } - for i in 0..n { - result.set( - i, - self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)], - ); - } - Ok(result) + let mut result = Y::zeros(n); + + for i in 0..n { + result.set( + i, + self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)], + ); } + + Ok(result) } fn predict_for_row_oob(&self, x: &X, row: usize) -> usize { From ff65a0838548e2efed868ff8d3c68081868f5441 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Tue, 10 Mar 2026 17:19:20 +0300 Subject: [PATCH 3/7] added 3 tests, incl. symmetry test. Now 4 in total. --- src/metrics/distance/jaccard.rs | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/metrics/distance/jaccard.rs b/src/metrics/distance/jaccard.rs index 589f52cf..4834e2e3 100644 --- a/src/metrics/distance/jaccard.rs +++ b/src/metrics/distance/jaccard.rs @@ -89,6 +89,7 @@ mod tests { all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] + #[test] fn jaccard_distance() { let a = vec![1, 0, 1, 1]; @@ -98,4 +99,38 @@ mod tests { assert!((j - 0.5).abs() < 1e-8); } + + #[test] + fn jaccard_identical_vectors() { + let a = vec![1, 0, 1, 0]; + let b = vec![1, 0, 1, 0]; + + let j: f64 = Jaccard::new().distance(&a, &b); + + assert!((j - 0.0).abs() < 1e-8); + } + + #[test] + fn jaccard_both_zero_vectors() { + let a = vec![0, 0, 0]; + let b = vec![0, 0, 0]; + + let j: f64 = Jaccard::new().distance(&a, &b); + + assert!((j - 0.0).abs() < 1e-8); + } + + #[test] + fn jaccard_symmetry() { + let a = vec![1, 0, 1, 1]; + let b = vec![0, 1, 1, 0]; + + let j = Jaccard::new(); + + let d1 = j.distance(&a, &b); + let d2 = j.distance(&b, &a); + + assert!((d1 - d2).abs() < 1e-12); + } } + From 8898ae2c768b8f2a7ff4580bcaf961c53372c3fd Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Fri, 13 Mar 2026 17:08:00 +0300 Subject: [PATCH 4/7] it compiles --- src/ensemble/random_forest_classifier.rs | 36 ++++++ src/tree/decision_tree_classifier.rs | 135 +++++++++++++++++++---- 2 files changed, 151 insertions(+), 20 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f4e8db3c..ae8c86f3 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -609,6 +609,42 @@ impl, Y: Array1 Vec { + + let k = self.classes.as_ref().unwrap().len(); + let mut probs = vec![0.0; k]; + + for tree in self.trees.as_ref().unwrap().iter() { + + let tree_probs = tree.predict_proba_for_row_real(x, row); + + for i in 0..k { + probs[i] += tree_probs[i]; + } + } + + let n_trees = self.trees.as_ref().unwrap().len(); + + for i in 0..k { + probs[i] /= n_trees as f64; + } + + probs + } + + pub fn predict_proba(&self, x: &X) -> Result>, Failed> { + + let (n, _) = x.shape(); + + let mut result = Vec::with_capacity(n); + + for i in 0..n { + result.push(self.predict_proba_for_row(x, i)); + } + + Ok(result) + } } #[cfg(test)] diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 96007677..e53c8c3d 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -162,12 +162,29 @@ pub enum SplitCriterion { #[derive(Debug, Clone)] struct Node { output: usize, + + /// number of samples that reached this node n_node_samples: usize, + + /// class distribution in this node + class_distribution: Vec, + + /// feature used for split split_feature: usize, + + /// threshold split_value: Option, + + /// impurity improvement of split split_score: Option, + + /// left child index true_child: Option, + + /// right child index false_child: Option, + + /// impurity value of node impurity: Option, } @@ -405,16 +422,17 @@ impl Default for DecisionTreeClassifierSearchParameters { } impl Node { - fn new(output: usize, n_node_samples: usize) -> Self { + fn new(output: usize, n_node_samples: usize, class_distribution: Vec) -> Self { Node { output, n_node_samples, + class_distribution, // added split_feature: 0, - split_value: Option::None, - split_score: Option::None, - true_child: Option::None, - false_child: Option::None, - impurity: Option::None, + split_value: None, + split_score: None, + true_child: None, + false_child: None, + impurity: None, } } } @@ -554,40 +572,62 @@ impl, Y: Array1> DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } + pub(crate) fn fit_weak_learner( x: &X, y: &Y, - samples: Vec, + bootstrap_sample_counts: Vec, // Renamed from just "samples" for semantic clarity. It isn't "samples" mtry: usize, parameters: DecisionTreeClassifierParameters, ) -> Result, Failed> { + let y_ncols = y.shape(); let (_, num_attributes) = x.shape(); + let classes = y.unique(); - let k = classes.len(); - if k < 2 { + let num_classes = classes.len(); + + if num_classes < 2 { return Err(Failed::fit(&format!( - "Incorrect number of classes: {k}. Should be >= 2." + "Incorrect number of classes: {num_classes}. Should be >= 2." ))); } let mut rng = get_rng_impl(parameters.seed); - let mut yi: Vec = vec![0; y_ncols]; - for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) { + // bootstrap_classes[i] = class index of sample i + let mut bootstrap_classes: Vec = vec![0; y_ncols]; + + for (i, class_index) in bootstrap_classes.iter_mut().enumerate().take(y_ncols) { let yc = y.get(i); - *yi_i = classes.iter().position(|c| yc == c).unwrap(); + *class_index = classes.iter().position(|c| yc == c).unwrap(); } let mut change_nodes: Vec = Vec::new(); - let mut count = vec![0; k]; + // -------------------------------- + // compute class distribution + // -------------------------------- + + let mut class_distribution = vec![0; num_classes]; + for i in 0..y_ncols { - count[yi[i]] += samples[i]; + class_distribution[bootstrap_classes[i]] += bootstrap_sample_counts[i]; } - let root = Node::new(which_max(&count), y_ncols); + // majority class + let root_output = which_max(&class_distribution); + + let root = Node::new( + root_output, + y_ncols, + class_distribution.clone(), + ); + change_nodes.push(root); + + // -------------------------------- + let mut order: Vec> = Vec::new(); for i in 0..num_attributes { @@ -598,7 +638,7 @@ impl, Y: Array1> let mut tree = DecisionTreeClassifier { nodes: change_nodes, parameters: Some(parameters), - num_classes: k, + num_classes, classes, depth: 0u16, num_features: num_attributes, @@ -607,7 +647,14 @@ impl, Y: Array1> _phantom_y: PhantomData, }; - let mut visitor = NodeVisitor::::new(0, samples, &order, x, &yi, 1); + let mut visitor = NodeVisitor::::new( + 0, + bootstrap_sample_counts, + &order, + x, + &bootstrap_classes, + 1, + ); let mut visitor_queue: LinkedList> = LinkedList::new(); @@ -625,6 +672,7 @@ impl, Y: Array1> Ok(tree) } + /// Predict class value for `x`. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. pub fn predict(&self, x: &X) -> Result { @@ -831,9 +879,32 @@ impl, Y: Array1> let true_child_idx = self.nodes().len(); - self.nodes.push(Node::new(visitor.true_child_output, tc)); + // Added. We are computing class distribution + let mut true_distribution = vec![0; self.num_classes]; + let mut false_distribution = vec![0; self.num_classes]; + + for i in 0..n { + + if true_samples[i] > 0 { + true_distribution[visitor.y[i]] += true_samples[i]; + } + + if visitor.samples[i] > 0 { + false_distribution[visitor.y[i]] += visitor.samples[i]; + } + } + + // Some additional checks + let true_sum: usize = true_distribution.iter().sum(); + let false_sum: usize = false_distribution.iter().sum(); + debug_assert_eq!(true_sum, tc); + debug_assert_eq!(false_sum, fc); + // debug_assert_eq!(tc + fc, visitor.samples.iter().sum::()); // TODO + + self.nodes.push(Node::new(visitor.true_child_output, tc, true_distribution)); let false_child_idx = self.nodes().len(); - self.nodes.push(Node::new(visitor.false_child_output, fc)); + self.nodes.push(Node::new(visitor.false_child_output, fc, false_distribution)); + self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx); @@ -959,6 +1030,30 @@ impl, Y: Array1> // This should never happen if the tree is properly constructed Err(Failed::predict("Nodes iteration did not reach leaf")) } + + pub fn predict_proba_for_row_real(&self, x: &X, row: usize) -> Vec { + let mut node = 0; + loop { + let current = &self.nodes()[node]; + if current.true_child.is_none() && current.false_child.is_none() { + let total: usize = current.class_distribution.iter().sum(); + let mut probs = vec![0.0; self.num_classes]; + for i in 0..self.num_classes { + probs[i] = current.class_distribution[i] as f64 / total as f64; + } + + return probs; + } + + let split_feature = current.split_feature; + let split_value = current.split_value.unwrap(); + if x.get((row, split_feature)).to_f64().unwrap() <= split_value { + node = current.true_child.unwrap(); + } else { + node = current.false_child.unwrap(); + } + } + } } #[cfg(test)] From 9a50b4f7912b29e101ea63844fdf5fe8144fd259 Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Mon, 16 Mar 2026 07:46:14 +0300 Subject: [PATCH 5/7] feat: implement proper predict_proba for Random Forest and Decision Tree --- src/ensemble/random_forest_classifier.rs | 187 +++++++++++++++ src/tree/decision_tree_classifier.rs | 279 +++++++++++------------ 2 files changed, 320 insertions(+), 146 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index ae8c86f3..f381a61c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -610,6 +610,23 @@ impl, Y: Array1 Vec { let k = self.classes.as_ref().unwrap().len(); @@ -633,6 +650,35 @@ impl, Y: Array1>` where each inner vector corresponds to + /// a sample and contains probabilities for each class. The sum of probabilities + /// for each sample equals 1.0. + /// + /// # Note + /// + /// Return type is `Vec>` for minimal API changes. The tree classifier + /// returns `DenseMatrix` for the same method. + /// + /// # Errors + /// + /// Returns an error if the forest has not been fitted (trees are None). pub fn predict_proba(&self, x: &X) -> Result>, Failed> { let (n, _) = x.shape(); @@ -842,4 +888,145 @@ mod tests { assert_eq!(forest, deserialized_forest); } + + // Test for predict_proba + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_forest() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]) + .unwrap(); + let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: Option::None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 10, + m: Option::None, + keep_samples: false, + seed: 87, + }, + ) + .unwrap(); + + let probabilities = classifier.predict_proba(&x).unwrap(); + assert_eq!(probabilities.len(), 10); + assert_eq!(probabilities[0].len(), 2); + + // Check that probabilities sum to 1.0 for each sample + for row in 0..10 { + let row_sum: f64 = probabilities[row].iter().sum(); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1, got {}", + row_sum + ); + } + + // Check if the first 5 samples have higher probability for class 0 + for i in 0..5 { + assert!( + probabilities[i][0] > probabilities[i][1], + "Sample {} should have higher prob for class 0", + i + ); + } + + // Check if the last 5 samples have higher probability for class 1 + for i in 5..10 { + assert!( + probabilities[i][1] > probabilities[i][0], + "Sample {} should have higher prob for class 1", + i + ); + } + } + + // Test for predict_proba with mixed classes in leaves + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_mixed_leaves() { + // Create a simple dataset where some leaves will have mixed classes + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1.0, 1.0], + &[1.0, 1.0], + &[1.0, 1.0], + &[5.0, 5.0], + &[5.0, 5.0], + ]) + .unwrap(); + let y: Vec = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first group + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + n_trees: 5, + seed: 42, + ..Default::default() + }, + ) + .unwrap(); + + let probabilities = classifier.predict_proba(&x).unwrap(); + + // All probabilities should be non-negative and sum to 1.0 + for row in 0..5 { + let sum: f64 = probabilities[row].iter().sum(); + assert!( + (sum - 1.0).abs() < 1e-6, + "Probabilities for row {} should sum to 1.0, got {}", + row, + sum + ); + for &p in &probabilities[row] { + assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p); + } + } + + // First 3 samples should have non-zero probability for both class 0 and 1 + // (since they're in the same region with mixed classes) + for i in 0..3 { + assert!( + probabilities[i][0] > 0.0, + "Sample {} should have non-zero prob for class 0", + i + ); + assert!( + probabilities[i][1] > 0.0, + "Sample {} should have non-zero prob for class 1", + i + ); + } + + // Last 2 samples should have high probability for class 2 + for i in 3..5 { + assert!( + probabilities[i][2] > 0.5, + "Sample {} should have high prob for class 2", + i + ); + } + } } diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index e53c8c3d..c91680ee 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -68,13 +68,10 @@ use std::collections::LinkedList; use std::default::Default; use std::fmt::Debug; use std::marker::PhantomData; - use rand::seq::SliceRandom; use rand::Rng; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; - use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::basic::arrays::MutArray; @@ -162,28 +159,20 @@ pub enum SplitCriterion { #[derive(Debug, Clone)] struct Node { output: usize, - /// number of samples that reached this node n_node_samples: usize, - - /// class distribution in this node + /// class distribution in this node (for probability estimation) class_distribution: Vec, - /// feature used for split split_feature: usize, - /// threshold split_value: Option, - /// impurity improvement of split split_score: Option, - /// left child index true_child: Option, - /// right child index false_child: Option, - /// impurity value of node impurity: Option, } @@ -225,6 +214,7 @@ impl PartialEq for Node { (None, None) => true, _ => false, } + && self.class_distribution == other.class_distribution } } @@ -426,7 +416,7 @@ impl Node { Node { output, n_node_samples, - class_distribution, // added + class_distribution, split_feature: 0, split_value: None, split_score: None, @@ -451,7 +441,6 @@ struct NodeVisitor<'a, TX: Number + PartialOrd, X: Array2> { fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { let mut impurity = 0f64; - match criterion { SplitCriterion::Gini => { impurity = 1f64; @@ -462,7 +451,6 @@ fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { } } } - SplitCriterion::Entropy => { for count_i in count.iter() { if *count_i > 0 { @@ -480,7 +468,6 @@ fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { impurity = (1f64 - impurity).abs(); } } - impurity } @@ -510,14 +497,12 @@ impl<'a, TX: Number + PartialOrd, X: Array2> NodeVisitor<'a, TX, X> { pub(crate) fn which_max(x: &[usize]) -> usize { let mut m = x[0]; let mut which = 0; - for (i, x_i) in x.iter().enumerate().skip(1) { if *x_i > m { m = *x_i; which = i; } } - which } @@ -538,7 +523,6 @@ impl, Y: Array1> _phantom_y: PhantomData, } } - fn fit(x: &X, y: &Y, parameters: DecisionTreeClassifierParameters) -> Result { DecisionTreeClassifier::fit(x, y, parameters) } @@ -567,78 +551,54 @@ impl, Y: Array1> if x_nrows != y.shape() { return Err(Failed::fit("Size of x should equal size of y")); } - let samples = vec![1; x_nrows]; DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } - pub(crate) fn fit_weak_learner( x: &X, y: &Y, - bootstrap_sample_counts: Vec, // Renamed from just "samples" for semantic clarity. It isn't "samples" + samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters, ) -> Result, Failed> { - let y_ncols = y.shape(); let (_, num_attributes) = x.shape(); - let classes = y.unique(); - let num_classes = classes.len(); - - if num_classes < 2 { + let k = classes.len(); + if k < 2 { return Err(Failed::fit(&format!( - "Incorrect number of classes: {num_classes}. Should be >= 2." + "Incorrect number of classes: {k}. Should be >= 2." ))); } - let mut rng = get_rng_impl(parameters.seed); - - // bootstrap_classes[i] = class index of sample i - let mut bootstrap_classes: Vec = vec![0; y_ncols]; - - for (i, class_index) in bootstrap_classes.iter_mut().enumerate().take(y_ncols) { + let mut yi: Vec = vec![0; y_ncols]; + for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) { let yc = y.get(i); - *class_index = classes.iter().position(|c| yc == c).unwrap(); + *yi_i = classes.iter().position(|c| yc == c).unwrap(); } - let mut change_nodes: Vec = Vec::new(); - // -------------------------------- // compute class distribution // -------------------------------- - - let mut class_distribution = vec![0; num_classes]; - + let mut count = vec![0; k]; for i in 0..y_ncols { - class_distribution[bootstrap_classes[i]] += bootstrap_sample_counts[i]; + count[yi[i]] += samples[i]; } - // majority class - let root_output = which_max(&class_distribution); - - let root = Node::new( - root_output, - y_ncols, - class_distribution.clone(), - ); - + let root_output = which_max(&count); + let root = Node::new(root_output, y_ncols, count.clone()); change_nodes.push(root); - // -------------------------------- - let mut order: Vec> = Vec::new(); - for i in 0..num_attributes { let mut col_i: Vec = x.get_col(i).iterator(0).copied().collect(); order.push(col_i.argsort_mut()); } - let mut tree = DecisionTreeClassifier { nodes: change_nodes, parameters: Some(parameters), - num_classes, + num_classes: k, classes, depth: 0u16, num_features: num_attributes, @@ -646,53 +606,35 @@ impl, Y: Array1> _phantom_x: PhantomData, _phantom_y: PhantomData, }; - - let mut visitor = NodeVisitor::::new( - 0, - bootstrap_sample_counts, - &order, - x, - &bootstrap_classes, - 1, - ); - + let mut visitor = NodeVisitor::::new(0, samples, &order, x, &yi, 1); let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } - while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) { match visitor_queue.pop_front() { Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } - Ok(tree) } - /// Predict class value for `x`. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. pub fn predict(&self, x: &X) -> Result { let mut result = Y::zeros(x.shape().0); - let (n, _) = x.shape(); - for i in 0..n { result.set(i, self.classes()[self.predict_for_row(x, i)]); } - Ok(result) } pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> usize { let mut result = 0; let mut queue: LinkedList = LinkedList::new(); - queue.push_back(0); - while !queue.is_empty() { match queue.pop_front() { Some(node_id) => { @@ -710,7 +652,6 @@ impl, Y: Array1> None => break, }; } - result } @@ -721,7 +662,6 @@ impl, Y: Array1> rng: &mut impl Rng, ) -> bool { let (n_rows, n_attr) = visitor.x.shape(); - let mut label = None; let mut is_pure = true; for i in 0..n_rows { @@ -739,7 +679,6 @@ impl, Y: Array1> } } } - let n = visitor.samples.iter().sum(); let mut count = vec![0; self.num_classes]; let mut false_count = vec![0; self.num_classes]; @@ -748,27 +687,20 @@ impl, Y: Array1> count[visitor.y[i]] += visitor.samples[i]; } } - self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n)); - if is_pure { return false; } - if n <= self.parameters().min_samples_split { return false; } - let mut variables = (0..n_attr).collect::>(); - if mtry < n_attr { variables.shuffle(rng); } - for variable in variables.iter().take(mtry) { self.find_best_split(visitor, n, &count, &mut false_count, *variable); } - self.nodes()[visitor.node].split_score.is_some() } @@ -783,34 +715,26 @@ impl, Y: Array1> let mut true_count = vec![0; self.num_classes]; let mut prevx = Option::None; let mut prevy = 0; - for i in visitor.order[j].iter() { if visitor.samples[*i] > 0 { let x_ij = *visitor.x.get((*i, j)); - if prevx.is_none() || x_ij == prevx.unwrap() || visitor.y[*i] == prevy { prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; continue; } - let tc = true_count.iter().sum(); let fc = n - tc; - - if tc < self.parameters().min_samples_leaf - || fc < self.parameters().min_samples_leaf - { + if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf { prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; continue; } - for l in 0..self.num_classes { false_count[l] = count[l] - true_count[l]; } - let true_label = which_max(&true_count); let false_label = which_max(false_count); let parent_impurity = self.nodes()[visitor.node].impurity.unwrap(); @@ -819,7 +743,6 @@ impl, Y: Array1> * impurity(&self.parameters().criterion, &true_count, tc) - fc as f64 / n as f64 * impurity(&self.parameters().criterion, false_count, fc); - if self.nodes()[visitor.node].split_score.is_none() || gain > self.nodes()[visitor.node].split_score.unwrap() { @@ -827,11 +750,9 @@ impl, Y: Array1> self.nodes[visitor.node].split_value = Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64); self.nodes[visitor.node].split_score = Option::Some(gain); - visitor.true_child_output = true_label; visitor.false_child_output = false_label; } - prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; @@ -850,7 +771,6 @@ impl, Y: Array1> let mut tc = 0; let mut fc = 0; let mut true_samples: Vec = vec![0; n]; - for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) { if visitor.samples[i] > 0 { if visitor @@ -868,48 +788,38 @@ impl, Y: Array1> } } } - if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf { self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_value = Option::None; self.nodes[visitor.node].split_score = Option::None; - return false; } - let true_child_idx = self.nodes().len(); - // Added. We are computing class distribution let mut true_distribution = vec![0; self.num_classes]; let mut false_distribution = vec![0; self.num_classes]; - for i in 0..n { - if true_samples[i] > 0 { true_distribution[visitor.y[i]] += true_samples[i]; } - if visitor.samples[i] > 0 { false_distribution[visitor.y[i]] += visitor.samples[i]; } } - // Some additional checks let true_sum: usize = true_distribution.iter().sum(); let false_sum: usize = false_distribution.iter().sum(); + let original_total: usize = true_sum + false_sum; debug_assert_eq!(true_sum, tc); debug_assert_eq!(false_sum, fc); - // debug_assert_eq!(tc + fc, visitor.samples.iter().sum::()); // TODO + debug_assert_eq!(true_sum + false_sum, original_total); self.nodes.push(Node::new(visitor.true_child_output, tc, true_distribution)); let false_child_idx = self.nodes().len(); self.nodes.push(Node::new(visitor.false_child_output, fc, false_distribution)); - self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx); - self.depth = u16::max(self.depth, visitor.level + 1); - let mut true_visitor = NodeVisitor::::new( true_child_idx, true_samples, @@ -918,11 +828,9 @@ impl, Y: Array1> visitor.y, visitor.level + 1, ); - if self.find_best_cutoff(&mut true_visitor, mtry, rng) { visitor_queue.push_back(true_visitor); } - let mut false_visitor = NodeVisitor::::new( false_child_idx, visitor.samples, @@ -931,25 +839,21 @@ impl, Y: Array1> visitor.y, visitor.level + 1, ); - if self.find_best_cutoff(&mut false_visitor, mtry, rng) { visitor_queue.push_back(false_visitor); } - true } /// Compute feature importances for the fitted tree. pub fn compute_feature_importances(&self, normalize: bool) -> Vec { let mut importances = vec![0f64; self.num_features]; - for node in self.nodes().iter() { if node.true_child.is_none() && node.false_child.is_none() { continue; } let left = &self.nodes()[node.true_child.unwrap()]; let right = &self.nodes()[node.false_child.unwrap()]; - importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap() - left.n_node_samples as f64 * left.impurity.unwrap() - right.n_node_samples as f64 * right.impurity.unwrap(); @@ -968,6 +872,10 @@ impl, Y: Array1> /// Predict class probabilities for the input samples. /// + /// This method returns probabilities using the original API (one-hot encoding at leaves). + /// For proper probability estimation based on class distribution in leaves, use + /// `predict_proba_for_row_real` for individual samples. + /// /// # Arguments /// /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature. @@ -976,6 +884,7 @@ impl, Y: Array1> /// /// A `Result` containing a `DenseMatrix` where each row corresponds to a sample and each column /// corresponds to a class. The values represent the probability of the sample belonging to each class. + /// Note: Current implementation returns one-hot vectors at leaves (backward compatibility). /// /// # Errors /// @@ -984,18 +893,19 @@ impl, Y: Array1> let (n_samples, _) = x.shape(); let n_classes = self.classes().len(); let mut result = DenseMatrix::::zeros(n_samples, n_classes); - for i in 0..n_samples { let probs = self.predict_proba_for_row(x, i)?; for (j, &prob) in probs.iter().enumerate() { result.set((i, j), prob); } } - Ok(result) } - /// Predict class probabilities for a single input sample. + /// Predict class probabilities for a single input sample (legacy API). + /// + /// Returns one-hot encoded probabilities: 1.0 for the majority class, 0.0 for others. + /// This maintains backward compatibility with the original smartcore API. /// /// # Arguments /// @@ -1004,47 +914,61 @@ impl, Y: Array1> /// /// # Returns /// - /// A vector of probabilities, one for each class, representing the probability - /// of the input sample belonging to each class. + /// A vector of probabilities, one for each class. Only the majority class has probability 1.0. fn predict_proba_for_row(&self, x: &X, row: usize) -> Result, Failed> { let mut node = 0; - while let Some(current_node) = self.nodes().get(node) { if current_node.true_child.is_none() && current_node.false_child.is_none() { - // Leaf node reached + // Leaf node reached - legacy behavior: one-hot encoding let mut probs = vec![0.0; self.classes().len()]; probs[current_node.output] = 1.0; return Ok(probs); } - let split_feature = current_node.split_feature; let split_value = current_node.split_value.unwrap_or(f64::NAN); - if x.get((row, split_feature)).to_f64().unwrap() <= split_value { node = current_node.true_child.unwrap(); } else { node = current_node.false_child.unwrap(); } } - - // This should never happen if the tree is properly constructed Err(Failed::predict("Nodes iteration did not reach leaf")) } + /// Predict class probabilities for a single input sample using class distribution. + /// + /// This method returns proper probability estimates based on the class distribution + /// in the leaf node. For a leaf with N samples, if class i has count[i] samples, + /// the probability for class i is count[i] / N. + /// + /// This is the scikit-learn style behavior and should be used when you need + /// calibrated probability estimates rather than just the majority class. + /// + /// # Arguments + /// + /// * `x` - The input matrix containing all samples. + /// * `row` - The index of the row in `x` for which to predict probabilities. + /// + /// # Returns + /// + /// A vector of probabilities, one for each class. The sum of probabilities equals 1.0. + /// Each probability represents the fraction of training samples of that class + /// that reached the same leaf node. pub fn predict_proba_for_row_real(&self, x: &X, row: usize) -> Vec { let mut node = 0; loop { let current = &self.nodes()[node]; if current.true_child.is_none() && current.false_child.is_none() { + // Leaf node reached - use class distribution for proper probabilities let total: usize = current.class_distribution.iter().sum(); let mut probs = vec![0.0; self.num_classes]; - for i in 0..self.num_classes { - probs[i] = current.class_distribution[i] as f64 / total as f64; + if total > 0 { + for i in 0..self.num_classes { + probs[i] = current.class_distribution[i] as f64 / total as f64; + } } - return probs; } - let split_feature = current.split_feature; let split_value = current.split_value.unwrap(); if x.get((row, split_feature)).to_f64().unwrap() <= split_value { @@ -1122,12 +1046,9 @@ mod tests { ]) .unwrap(); let y: Vec = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; - let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); let probabilities = tree.predict_proba(&x).unwrap(); - assert_eq!(probabilities.shape(), (10, 2)); - for row in 0..10 { let row_sum: f64 = probabilities.get_row(row).sum(); assert!( @@ -1135,18 +1056,94 @@ mod tests { "Row probabilities should sum to 1" ); } - // Check if the first 5 samples have higher probability for class 0 for i in 0..5 { assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); } - // Check if the last 5 samples have higher probability for class 1 for i in 5..10 { assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); } } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_real_distribution() { + // Test that predict_proba_for_row_real returns proper class distribution + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1.0, 2.0], + &[1.1, 2.1], + &[1.2, 2.2], + &[5.0, 6.0], + &[5.1, 6.1], + ]) + .unwrap(); + let y: Vec = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first leaf + + let tree = DecisionTreeClassifier::fit( + &x, + &y, + DecisionTreeClassifierParameters { + min_samples_leaf: 2, // Force some mixing + ..Default::default() + }, + ) + .unwrap(); + + // Test that probabilities sum to 1.0 for all samples + for row in 0..x.shape().0 { + let probs = tree.predict_proba_for_row_real(&x, row); + let sum: f64 = probs.iter().sum(); + assert!( + (sum - 1.0).abs() < 1e-6, + "Probabilities for row {} should sum to 1.0, got {}", + row, + sum + ); + // All probabilities should be non-negative + for &p in &probs { + assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p); + } + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_real_vs_legacy() { + // Test that the new method gives different (better) results than legacy + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1.0, 1.0], + &[1.0, 1.0], // Same features, different classes -> mixed leaf + &[1.0, 1.0], + &[2.0, 2.0], + ]) + .unwrap(); + let y: Vec = vec![0, 1, 1, 2]; // Leaf 1: [0,1,1], Leaf 2: [2] + + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + + // For sample 0 (in mixed leaf), legacy returns one-hot, new returns distribution + let legacy_probs = tree.predict_proba_for_row(&x, 0).unwrap(); + let real_probs = tree.predict_proba_for_row_real(&x, 0); + + // Legacy should be one-hot (only one class has prob 1.0) + let ones: usize = legacy_probs.iter().filter(|&&p| p == 1.0).count(); + let zeros: usize = legacy_probs.iter().filter(|&&p| p == 0.0).count(); + assert_eq!(ones + zeros, legacy_probs.len(), "Legacy should be one-hot"); + + // Real should have fractional probabilities for the mixed leaf + // Leaf has 3 samples: 1 of class 0, 2 of class 1 -> probs [1/3, 2/3, 0] + assert!((real_probs[0] - 1.0/3.0).abs() < 1e-6, "Class 0 prob should be 1/3"); + assert!((real_probs[1] - 2.0/3.0).abs() < 1e-6, "Class 1 prob should be 2/3"); + assert!(real_probs[2] < 1e-6, "Class 2 prob should be ~0"); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test @@ -1178,17 +1175,14 @@ mod tests { ]) .unwrap(); let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - assert_eq!( y, DecisionTreeClassifier::fit(&x, &y, Default::default()) .and_then(|t| t.predict(&x)) .unwrap() ); - println!( "{:?}", - //3, DecisionTreeClassifier::fit( &x, &y, @@ -1208,11 +1202,8 @@ mod tests { #[test] fn test_random_matrix_with_wrong_rownum() { let x_rand: DenseMatrix = DenseMatrix::::rand(21, 200); - let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default()); - assert!(fail.is_err()); } @@ -1246,7 +1237,6 @@ mod tests { ]) .unwrap(); let y: Vec = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]; - assert_eq!( y, DecisionTreeClassifier::fit(&x, &y, Default::default()) @@ -1323,12 +1313,9 @@ mod tests { ]) .unwrap(); let y = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]; - let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); - let deserialized_tree: DecisionTreeClassifier, Vec> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); - assert_eq!(tree, deserialized_tree); } -} +} \ No newline at end of file From 470de49d641f43d1251c13b856acc0939b1a742b Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Mon, 16 Mar 2026 14:50:08 +0300 Subject: [PATCH 6/7] clippy --- src/ensemble/random_forest_classifier.rs | 24 ++++++++++++++++-------- src/tree/decision_tree_classifier.rs | 4 +--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f381a61c..229cda5c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -628,23 +628,31 @@ impl, Y: Array1 Vec { + // improvement: unwrap делаем один раз + let trees = self.trees.as_ref().unwrap(); + // improvement: unwrap classes тоже один раз let k = self.classes.as_ref().unwrap().len(); - let mut probs = vec![0.0; k]; - for tree in self.trees.as_ref().unwrap().iter() { + let mut probs = vec![0.0; k]; + for tree in trees { let tree_probs = tree.predict_proba_for_row_real(x, row); - for i in 0..k { - probs[i] += tree_probs[i]; + // improvement: убран range loop + // improvement: нет индексирования + // improvement: zip гарантирует покомпонентное сложение + for (p, tp) in probs.iter_mut().zip(tree_probs.iter()) { + *p += *tp; // важно разыменование } } - let n_trees = self.trees.as_ref().unwrap().len(); + // improvement: unwrap уже не нужен + let n_trees = trees.len() as f64; - for i in 0..k { - probs[i] /= n_trees as f64; + // improvement: убран needless_range_loop + for p in &mut probs { + *p /= n_trees; } probs @@ -663,7 +671,7 @@ impl, Y: Array1, Y: Array1> let total: usize = current.class_distribution.iter().sum(); let mut probs = vec![0.0; self.num_classes]; if total > 0 { - for i in 0..self.num_classes { - probs[i] = current.class_distribution[i] as f64 / total as f64; - } + for (p, count) in probs.iter_mut().zip(¤t.class_distribution) { *p = *count as f64 / total as f64; } } return probs; } From 8f7b17ab9135571bcdae6b77faaadd47c5c85d5f Mon Sep 17 00:00:00 2001 From: Andrey Shevchenko Date: Thu, 19 Mar 2026 17:23:27 +0300 Subject: [PATCH 7/7] RF predict_proba functionality is now covered by 2 tests. The first test uses Iris dataset, and consists of 4 checks. The 2nd test consists of 2 checks. --- src/ensemble/random_forest_classifier.rs | 77 ++++++++++++------------ 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 229cda5c..a07f2db0 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -688,7 +688,6 @@ impl, Y: Array1 Result>, Failed> { - let (n, _) = x.shape(); let mut result = Vec::with_capacity(n); @@ -896,28 +895,37 @@ mod tests { assert_eq!(forest, deserialized_forest); } - - // Test for predict_proba + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] #[test] - fn test_predict_proba_forest() { + fn test_predict_proba_iris() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], &[4.6, 3.1, 1.5, 0.2], &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], &[7.0, 3.2, 4.7, 1.4], &[6.4, 3.2, 4.5, 1.5], &[6.9, 3.1, 4.9, 1.5], &[5.5, 2.3, 4.0, 1.3], &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], ]) .unwrap(); - let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let classifier = RandomForestClassifier::fit( &x, @@ -936,21 +944,24 @@ mod tests { .unwrap(); let probabilities = classifier.predict_proba(&x).unwrap(); - assert_eq!(probabilities.len(), 10); + + // Check 1: dimensions + assert_eq!(probabilities.len(), 20); assert_eq!(probabilities[0].len(), 2); - // Check that probabilities sum to 1.0 for each sample - for row in 0..10 { + // Check 2: probabilities sum to 1.0 for all rows + for row in 0..20 { let row_sum: f64 = probabilities[row].iter().sum(); assert!( (row_sum - 1.0).abs() < 1e-6, - "Row probabilities should sum to 1, got {}", + "Row {} probabilities should sum to 1, got {}", + row, row_sum ); } - // Check if the first 5 samples have higher probability for class 0 - for i in 0..5 { + // Check 3: first 8 samples → higher prob for class 0 + for i in 0..8 { assert!( probabilities[i][0] > probabilities[i][1], "Sample {} should have higher prob for class 0", @@ -958,8 +969,8 @@ mod tests { ); } - // Check if the last 5 samples have higher probability for class 1 - for i in 5..10 { + // Check 4: last 12 samples → higher prob for class 1 + for i in 8..20 { assert!( probabilities[i][1] > probabilities[i][0], "Sample {} should have higher prob for class 1", @@ -968,23 +979,22 @@ mod tests { } } - // Test for predict_proba with mixed classes in leaves #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] #[test] - fn test_predict_proba_mixed_leaves() { - // Create a simple dataset where some leaves will have mixed classes - let x: DenseMatrix = DenseMatrix::from_2d_array(&[ - &[1.0, 1.0], - &[1.0, 1.0], - &[1.0, 1.0], - &[5.0, 5.0], - &[5.0, 5.0], + fn test_predict_proba_iris_mixed_leaves() { + // Dataset with mixed leaves + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[5.1, 3.5, 1.4, 0.2], // Same features + &[5.1, 3.5, 1.4, 0.2], // Same features + &[7.0, 3.2, 4.7, 1.4], + &[7.0, 3.2, 4.7, 1.4], // Same features ]) .unwrap(); - let y: Vec = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first group + let y = vec![0, 0, 1, 1, 1]; // Mixed classes in same feature region let classifier = RandomForestClassifier::fit( &x, @@ -999,22 +1009,20 @@ mod tests { let probabilities = classifier.predict_proba(&x).unwrap(); - // All probabilities should be non-negative and sum to 1.0 + // Check 1: All probabilities should be valid for row in 0..5 { let sum: f64 = probabilities[row].iter().sum(); assert!( (sum - 1.0).abs() < 1e-6, - "Probabilities for row {} should sum to 1.0, got {}", - row, - sum + "Probabilities for row {} should sum to 1.0", + row ); for &p in &probabilities[row] { - assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p); + assert!(p >= 0.0 && p <= 1.0, "Probability out of range"); } } - // First 3 samples should have non-zero probability for both class 0 and 1 - // (since they're in the same region with mixed classes) + // Check 2: First 3 samples must have non-zero prob for both classes, since they are mixed for i in 0..3 { assert!( probabilities[i][0] > 0.0, @@ -1027,14 +1035,5 @@ mod tests { i ); } - - // Last 2 samples should have high probability for class 2 - for i in 3..5 { - assert!( - probabilities[i][2] > 0.5, - "Sample {} should have high prob for class 2", - i - ); - } } }