diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f4e8db3c..a07f2db0 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -609,6 +609,95 @@ impl, Y: Array1 Vec { + // improvement: unwrap делаем один раз + let trees = self.trees.as_ref().unwrap(); + + // improvement: unwrap classes тоже один раз + let k = self.classes.as_ref().unwrap().len(); + + let mut probs = vec![0.0; k]; + + for tree in trees { + let tree_probs = tree.predict_proba_for_row_real(x, row); + + // improvement: убран range loop + // improvement: нет индексирования + // improvement: zip гарантирует покомпонентное сложение + for (p, tp) in probs.iter_mut().zip(tree_probs.iter()) { + *p += *tp; // важно разыменование + } + } + + // improvement: unwrap уже не нужен + let n_trees = trees.len() as f64; + + // improvement: убран needless_range_loop + for p in &mut probs { + *p /= n_trees; + } + + probs + } + + /// Predict class probabilities for the input samples. + /// + /// This method returns probability estimates for each sample in the input matrix. + /// For each sample, probabilities are computed by averaging the predictions from + /// all trees in the forest. Each tree contributes a probability distribution based + /// on the class distribution in its leaf node. + /// + /// This is the scikit-learn style `predict_proba` behavior, providing calibrated + /// probability estimates rather than just class predictions. + /// + /// # Arguments + /// + /// * `x` - The input samples as a matrix where each row is a sample and each column + /// is a feature. + /// + /// # Returns + /// + /// A `Result` containing a `Vec>` where each inner vector corresponds to + /// a sample and contains probabilities for each class. The sum of probabilities + /// for each sample equals 1.0. + /// + /// # Note + /// + /// Return type is `Vec>` for minimal API changes. The tree classifier + /// returns `DenseMatrix` for the same method. + /// + /// # Errors + /// + /// Returns an error if the forest has not been fitted (trees are None). + pub fn predict_proba(&self, x: &X) -> Result>, Failed> { + let (n, _) = x.shape(); + + let mut result = Vec::with_capacity(n); + + for i in 0..n { + result.push(self.predict_proba_for_row(x, i)); + } + + Ok(result) + } } #[cfg(test)] @@ -806,4 +895,145 @@ mod tests { assert_eq!(forest, deserialized_forest); } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_iris() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]) + .unwrap(); + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: Option::None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 10, + m: Option::None, + keep_samples: false, + seed: 87, + }, + ) + .unwrap(); + + let probabilities = classifier.predict_proba(&x).unwrap(); + + // Check 1: dimensions + assert_eq!(probabilities.len(), 20); + assert_eq!(probabilities[0].len(), 2); + + // Check 2: probabilities sum to 1.0 for all rows + for row in 0..20 { + let row_sum: f64 = probabilities[row].iter().sum(); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row {} probabilities should sum to 1, got {}", + row, + row_sum + ); + } + + // Check 3: first 8 samples → higher prob for class 0 + for i in 0..8 { + assert!( + probabilities[i][0] > probabilities[i][1], + "Sample {} should have higher prob for class 0", + i + ); + } + + // Check 4: last 12 samples → higher prob for class 1 + for i in 8..20 { + assert!( + probabilities[i][1] > probabilities[i][0], + "Sample {} should have higher prob for class 1", + i + ); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_iris_mixed_leaves() { + // Dataset with mixed leaves + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[5.1, 3.5, 1.4, 0.2], // Same features + &[5.1, 3.5, 1.4, 0.2], // Same features + &[7.0, 3.2, 4.7, 1.4], + &[7.0, 3.2, 4.7, 1.4], // Same features + ]) + .unwrap(); + let y = vec![0, 0, 1, 1, 1]; // Mixed classes in same feature region + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + n_trees: 5, + seed: 42, + ..Default::default() + }, + ) + .unwrap(); + + let probabilities = classifier.predict_proba(&x).unwrap(); + + // Check 1: All probabilities should be valid + for row in 0..5 { + let sum: f64 = probabilities[row].iter().sum(); + assert!( + (sum - 1.0).abs() < 1e-6, + "Probabilities for row {} should sum to 1.0", + row + ); + for &p in &probabilities[row] { + assert!(p >= 0.0 && p <= 1.0, "Probability out of range"); + } + } + + // Check 2: First 3 samples must have non-zero prob for both classes, since they are mixed + for i in 0..3 { + assert!( + probabilities[i][0] > 0.0, + "Sample {} should have non-zero prob for class 0", + i + ); + assert!( + probabilities[i][1] > 0.0, + "Sample {} should have non-zero prob for class 1", + i + ); + } + } } diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 96007677..c9dc7d0f 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -68,13 +68,10 @@ use std::collections::LinkedList; use std::default::Default; use std::fmt::Debug; use std::marker::PhantomData; - use rand::seq::SliceRandom; use rand::Rng; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; - use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::basic::arrays::MutArray; @@ -162,12 +159,21 @@ pub enum SplitCriterion { #[derive(Debug, Clone)] struct Node { output: usize, + /// number of samples that reached this node n_node_samples: usize, + /// class distribution in this node (for probability estimation) + class_distribution: Vec, + /// feature used for split split_feature: usize, + /// threshold split_value: Option, + /// impurity improvement of split split_score: Option, + /// left child index true_child: Option, + /// right child index false_child: Option, + /// impurity value of node impurity: Option, } @@ -208,6 +214,7 @@ impl PartialEq for Node { (None, None) => true, _ => false, } + && self.class_distribution == other.class_distribution } } @@ -405,16 +412,17 @@ impl Default for DecisionTreeClassifierSearchParameters { } impl Node { - fn new(output: usize, n_node_samples: usize) -> Self { + fn new(output: usize, n_node_samples: usize, class_distribution: Vec) -> Self { Node { output, n_node_samples, + class_distribution, split_feature: 0, - split_value: Option::None, - split_score: Option::None, - true_child: Option::None, - false_child: Option::None, - impurity: Option::None, + split_value: None, + split_score: None, + true_child: None, + false_child: None, + impurity: None, } } } @@ -433,7 +441,6 @@ struct NodeVisitor<'a, TX: Number + PartialOrd, X: Array2> { fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { let mut impurity = 0f64; - match criterion { SplitCriterion::Gini => { impurity = 1f64; @@ -444,7 +451,6 @@ fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { } } } - SplitCriterion::Entropy => { for count_i in count.iter() { if *count_i > 0 { @@ -462,7 +468,6 @@ fn impurity(criterion: &SplitCriterion, count: &[usize], n: usize) -> f64 { impurity = (1f64 - impurity).abs(); } } - impurity } @@ -492,14 +497,12 @@ impl<'a, TX: Number + PartialOrd, X: Array2> NodeVisitor<'a, TX, X> { pub(crate) fn which_max(x: &[usize]) -> usize { let mut m = x[0]; let mut which = 0; - for (i, x_i) in x.iter().enumerate().skip(1) { if *x_i > m { m = *x_i; which = i; } } - which } @@ -520,7 +523,6 @@ impl, Y: Array1> _phantom_y: PhantomData, } } - fn fit(x: &X, y: &Y, parameters: DecisionTreeClassifierParameters) -> Result { DecisionTreeClassifier::fit(x, y, parameters) } @@ -549,7 +551,6 @@ impl, Y: Array1> if x_nrows != y.shape() { return Err(Failed::fit("Size of x should equal size of y")); } - let samples = vec![1; x_nrows]; DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } @@ -570,31 +571,30 @@ impl, Y: Array1> "Incorrect number of classes: {k}. Should be >= 2." ))); } - let mut rng = get_rng_impl(parameters.seed); let mut yi: Vec = vec![0; y_ncols]; - for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) { let yc = y.get(i); *yi_i = classes.iter().position(|c| yc == c).unwrap(); } - let mut change_nodes: Vec = Vec::new(); - + // -------------------------------- + // compute class distribution + // -------------------------------- let mut count = vec![0; k]; for i in 0..y_ncols { count[yi[i]] += samples[i]; } - - let root = Node::new(which_max(&count), y_ncols); + // majority class + let root_output = which_max(&count); + let root = Node::new(root_output, y_ncols, count.clone()); change_nodes.push(root); + // -------------------------------- let mut order: Vec> = Vec::new(); - for i in 0..num_attributes { let mut col_i: Vec = x.get_col(i).iterator(0).copied().collect(); order.push(col_i.argsort_mut()); } - let mut tree = DecisionTreeClassifier { nodes: change_nodes, parameters: Some(parameters), @@ -606,22 +606,17 @@ impl, Y: Array1> _phantom_x: PhantomData, _phantom_y: PhantomData, }; - let mut visitor = NodeVisitor::::new(0, samples, &order, x, &yi, 1); - let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } - while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) { match visitor_queue.pop_front() { Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } - Ok(tree) } @@ -629,22 +624,17 @@ impl, Y: Array1> /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. pub fn predict(&self, x: &X) -> Result { let mut result = Y::zeros(x.shape().0); - let (n, _) = x.shape(); - for i in 0..n { result.set(i, self.classes()[self.predict_for_row(x, i)]); } - Ok(result) } pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> usize { let mut result = 0; let mut queue: LinkedList = LinkedList::new(); - queue.push_back(0); - while !queue.is_empty() { match queue.pop_front() { Some(node_id) => { @@ -662,7 +652,6 @@ impl, Y: Array1> None => break, }; } - result } @@ -673,7 +662,6 @@ impl, Y: Array1> rng: &mut impl Rng, ) -> bool { let (n_rows, n_attr) = visitor.x.shape(); - let mut label = None; let mut is_pure = true; for i in 0..n_rows { @@ -691,7 +679,6 @@ impl, Y: Array1> } } } - let n = visitor.samples.iter().sum(); let mut count = vec![0; self.num_classes]; let mut false_count = vec![0; self.num_classes]; @@ -700,27 +687,20 @@ impl, Y: Array1> count[visitor.y[i]] += visitor.samples[i]; } } - self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n)); - if is_pure { return false; } - if n <= self.parameters().min_samples_split { return false; } - let mut variables = (0..n_attr).collect::>(); - if mtry < n_attr { variables.shuffle(rng); } - for variable in variables.iter().take(mtry) { self.find_best_split(visitor, n, &count, &mut false_count, *variable); } - self.nodes()[visitor.node].split_score.is_some() } @@ -735,34 +715,26 @@ impl, Y: Array1> let mut true_count = vec![0; self.num_classes]; let mut prevx = Option::None; let mut prevy = 0; - for i in visitor.order[j].iter() { if visitor.samples[*i] > 0 { let x_ij = *visitor.x.get((*i, j)); - if prevx.is_none() || x_ij == prevx.unwrap() || visitor.y[*i] == prevy { prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; continue; } - let tc = true_count.iter().sum(); let fc = n - tc; - - if tc < self.parameters().min_samples_leaf - || fc < self.parameters().min_samples_leaf - { + if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf { prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; continue; } - for l in 0..self.num_classes { false_count[l] = count[l] - true_count[l]; } - let true_label = which_max(&true_count); let false_label = which_max(false_count); let parent_impurity = self.nodes()[visitor.node].impurity.unwrap(); @@ -771,7 +743,6 @@ impl, Y: Array1> * impurity(&self.parameters().criterion, &true_count, tc) - fc as f64 / n as f64 * impurity(&self.parameters().criterion, false_count, fc); - if self.nodes()[visitor.node].split_score.is_none() || gain > self.nodes()[visitor.node].split_score.unwrap() { @@ -779,11 +750,9 @@ impl, Y: Array1> self.nodes[visitor.node].split_value = Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64); self.nodes[visitor.node].split_score = Option::Some(gain); - visitor.true_child_output = true_label; visitor.false_child_output = false_label; } - prevx = Some(x_ij); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i]; @@ -802,7 +771,6 @@ impl, Y: Array1> let mut tc = 0; let mut fc = 0; let mut true_samples: Vec = vec![0; n]; - for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) { if visitor.samples[i] > 0 { if visitor @@ -820,25 +788,38 @@ impl, Y: Array1> } } } - if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf { self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_value = Option::None; self.nodes[visitor.node].split_score = Option::None; - return false; } - let true_child_idx = self.nodes().len(); - - self.nodes.push(Node::new(visitor.true_child_output, tc)); + // Added. We are computing class distribution + let mut true_distribution = vec![0; self.num_classes]; + let mut false_distribution = vec![0; self.num_classes]; + for i in 0..n { + if true_samples[i] > 0 { + true_distribution[visitor.y[i]] += true_samples[i]; + } + if visitor.samples[i] > 0 { + false_distribution[visitor.y[i]] += visitor.samples[i]; + } + } + // Some additional checks + let true_sum: usize = true_distribution.iter().sum(); + let false_sum: usize = false_distribution.iter().sum(); + let original_total: usize = true_sum + false_sum; + debug_assert_eq!(true_sum, tc); + debug_assert_eq!(false_sum, fc); + debug_assert_eq!(true_sum + false_sum, original_total); + + self.nodes.push(Node::new(visitor.true_child_output, tc, true_distribution)); let false_child_idx = self.nodes().len(); - self.nodes.push(Node::new(visitor.false_child_output, fc)); + self.nodes.push(Node::new(visitor.false_child_output, fc, false_distribution)); self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx); - self.depth = u16::max(self.depth, visitor.level + 1); - let mut true_visitor = NodeVisitor::::new( true_child_idx, true_samples, @@ -847,11 +828,9 @@ impl, Y: Array1> visitor.y, visitor.level + 1, ); - if self.find_best_cutoff(&mut true_visitor, mtry, rng) { visitor_queue.push_back(true_visitor); } - let mut false_visitor = NodeVisitor::::new( false_child_idx, visitor.samples, @@ -860,25 +839,21 @@ impl, Y: Array1> visitor.y, visitor.level + 1, ); - if self.find_best_cutoff(&mut false_visitor, mtry, rng) { visitor_queue.push_back(false_visitor); } - true } /// Compute feature importances for the fitted tree. pub fn compute_feature_importances(&self, normalize: bool) -> Vec { let mut importances = vec![0f64; self.num_features]; - for node in self.nodes().iter() { if node.true_child.is_none() && node.false_child.is_none() { continue; } let left = &self.nodes()[node.true_child.unwrap()]; let right = &self.nodes()[node.false_child.unwrap()]; - importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap() - left.n_node_samples as f64 * left.impurity.unwrap() - right.n_node_samples as f64 * right.impurity.unwrap(); @@ -897,6 +872,10 @@ impl, Y: Array1> /// Predict class probabilities for the input samples. /// + /// This method returns probabilities using the original API (one-hot encoding at leaves). + /// For proper probability estimation based on class distribution in leaves, use + /// `predict_proba_for_row_real` for individual samples. + /// /// # Arguments /// /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature. @@ -905,6 +884,7 @@ impl, Y: Array1> /// /// A `Result` containing a `DenseMatrix` where each row corresponds to a sample and each column /// corresponds to a class. The values represent the probability of the sample belonging to each class. + /// Note: Current implementation returns one-hot vectors at leaves (backward compatibility). /// /// # Errors /// @@ -913,18 +893,19 @@ impl, Y: Array1> let (n_samples, _) = x.shape(); let n_classes = self.classes().len(); let mut result = DenseMatrix::::zeros(n_samples, n_classes); - for i in 0..n_samples { let probs = self.predict_proba_for_row(x, i)?; for (j, &prob) in probs.iter().enumerate() { result.set((i, j), prob); } } - Ok(result) } - /// Predict class probabilities for a single input sample. + /// Predict class probabilities for a single input sample (legacy API). + /// + /// Returns one-hot encoded probabilities: 1.0 for the majority class, 0.0 for others. + /// This maintains backward compatibility with the original smartcore API. /// /// # Arguments /// @@ -933,32 +914,68 @@ impl, Y: Array1> /// /// # Returns /// - /// A vector of probabilities, one for each class, representing the probability - /// of the input sample belonging to each class. + /// A vector of probabilities, one for each class. Only the majority class has probability 1.0. fn predict_proba_for_row(&self, x: &X, row: usize) -> Result, Failed> { let mut node = 0; - while let Some(current_node) = self.nodes().get(node) { if current_node.true_child.is_none() && current_node.false_child.is_none() { - // Leaf node reached + // Leaf node reached - legacy behavior: one-hot encoding let mut probs = vec![0.0; self.classes().len()]; probs[current_node.output] = 1.0; return Ok(probs); } - let split_feature = current_node.split_feature; let split_value = current_node.split_value.unwrap_or(f64::NAN); - if x.get((row, split_feature)).to_f64().unwrap() <= split_value { node = current_node.true_child.unwrap(); } else { node = current_node.false_child.unwrap(); } } - - // This should never happen if the tree is properly constructed Err(Failed::predict("Nodes iteration did not reach leaf")) } + + /// Predict class probabilities for a single input sample using class distribution. + /// + /// This method returns proper probability estimates based on the class distribution + /// in the leaf node. For a leaf with N samples, if class i has count[i] samples, + /// the probability for class i is count[i] / N. + /// + /// This is the scikit-learn style behavior and should be used when you need + /// calibrated probability estimates rather than just the majority class. + /// + /// # Arguments + /// + /// * `x` - The input matrix containing all samples. + /// * `row` - The index of the row in `x` for which to predict probabilities. + /// + /// # Returns + /// + /// A vector of probabilities, one for each class. The sum of probabilities equals 1.0. + /// Each probability represents the fraction of training samples of that class + /// that reached the same leaf node. + pub fn predict_proba_for_row_real(&self, x: &X, row: usize) -> Vec { + let mut node = 0; + loop { + let current = &self.nodes()[node]; + if current.true_child.is_none() && current.false_child.is_none() { + // Leaf node reached - use class distribution for proper probabilities + let total: usize = current.class_distribution.iter().sum(); + let mut probs = vec![0.0; self.num_classes]; + if total > 0 { + for (p, count) in probs.iter_mut().zip(¤t.class_distribution) { *p = *count as f64 / total as f64; } + } + return probs; + } + let split_feature = current.split_feature; + let split_value = current.split_value.unwrap(); + if x.get((row, split_feature)).to_f64().unwrap() <= split_value { + node = current.true_child.unwrap(); + } else { + node = current.false_child.unwrap(); + } + } + } } #[cfg(test)] @@ -1027,12 +1044,9 @@ mod tests { ]) .unwrap(); let y: Vec = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; - let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); let probabilities = tree.predict_proba(&x).unwrap(); - assert_eq!(probabilities.shape(), (10, 2)); - for row in 0..10 { let row_sum: f64 = probabilities.get_row(row).sum(); assert!( @@ -1040,18 +1054,94 @@ mod tests { "Row probabilities should sum to 1" ); } - // Check if the first 5 samples have higher probability for class 0 for i in 0..5 { assert!(probabilities.get((i, 0)) > probabilities.get((i, 1))); } - // Check if the last 5 samples have higher probability for class 1 for i in 5..10 { assert!(probabilities.get((i, 1)) > probabilities.get((i, 0))); } } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_real_distribution() { + // Test that predict_proba_for_row_real returns proper class distribution + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1.0, 2.0], + &[1.1, 2.1], + &[1.2, 2.2], + &[5.0, 6.0], + &[5.1, 6.1], + ]) + .unwrap(); + let y: Vec = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first leaf + + let tree = DecisionTreeClassifier::fit( + &x, + &y, + DecisionTreeClassifierParameters { + min_samples_leaf: 2, // Force some mixing + ..Default::default() + }, + ) + .unwrap(); + + // Test that probabilities sum to 1.0 for all samples + for row in 0..x.shape().0 { + let probs = tree.predict_proba_for_row_real(&x, row); + let sum: f64 = probs.iter().sum(); + assert!( + (sum - 1.0).abs() < 1e-6, + "Probabilities for row {} should sum to 1.0, got {}", + row, + sum + ); + // All probabilities should be non-negative + for &p in &probs { + assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p); + } + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_predict_proba_real_vs_legacy() { + // Test that the new method gives different (better) results than legacy + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1.0, 1.0], + &[1.0, 1.0], // Same features, different classes -> mixed leaf + &[1.0, 1.0], + &[2.0, 2.0], + ]) + .unwrap(); + let y: Vec = vec![0, 1, 1, 2]; // Leaf 1: [0,1,1], Leaf 2: [2] + + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + + // For sample 0 (in mixed leaf), legacy returns one-hot, new returns distribution + let legacy_probs = tree.predict_proba_for_row(&x, 0).unwrap(); + let real_probs = tree.predict_proba_for_row_real(&x, 0); + + // Legacy should be one-hot (only one class has prob 1.0) + let ones: usize = legacy_probs.iter().filter(|&&p| p == 1.0).count(); + let zeros: usize = legacy_probs.iter().filter(|&&p| p == 0.0).count(); + assert_eq!(ones + zeros, legacy_probs.len(), "Legacy should be one-hot"); + + // Real should have fractional probabilities for the mixed leaf + // Leaf has 3 samples: 1 of class 0, 2 of class 1 -> probs [1/3, 2/3, 0] + assert!((real_probs[0] - 1.0/3.0).abs() < 1e-6, "Class 0 prob should be 1/3"); + assert!((real_probs[1] - 2.0/3.0).abs() < 1e-6, "Class 1 prob should be 2/3"); + assert!(real_probs[2] < 1e-6, "Class 2 prob should be ~0"); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test @@ -1083,17 +1173,14 @@ mod tests { ]) .unwrap(); let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - assert_eq!( y, DecisionTreeClassifier::fit(&x, &y, Default::default()) .and_then(|t| t.predict(&x)) .unwrap() ); - println!( "{:?}", - //3, DecisionTreeClassifier::fit( &x, &y, @@ -1113,11 +1200,8 @@ mod tests { #[test] fn test_random_matrix_with_wrong_rownum() { let x_rand: DenseMatrix = DenseMatrix::::rand(21, 200); - let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default()); - assert!(fail.is_err()); } @@ -1151,7 +1235,6 @@ mod tests { ]) .unwrap(); let y: Vec = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]; - assert_eq!( y, DecisionTreeClassifier::fit(&x, &y, Default::default()) @@ -1228,12 +1311,9 @@ mod tests { ]) .unwrap(); let y = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]; - let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); - let deserialized_tree: DecisionTreeClassifier, Vec> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap(); - assert_eq!(tree, deserialized_tree); } -} +} \ No newline at end of file