diff --git a/encodings/alp/benches/alp_compress.rs b/encodings/alp/benches/alp_compress.rs index a67583fd72a..239541533f8 100644 --- a/encodings/alp/benches/alp_compress.rs +++ b/encodings/alp/benches/alp_compress.rs @@ -103,20 +103,47 @@ fn decompress_alp(bencher: Bencher, args: (usize, f64 .bench_values(|(v, mut ctx)| decompress_into_array(v, &mut ctx)); } -#[divan::bench(types = [f32, f64], args = [10_000, 100_000])] -fn compress_rd(bencher: Bencher, n: usize) { - let primitive = PrimitiveArray::new(buffer![T::from(1.23).unwrap(); n], Validity::NonNullable); - let encoder = RDEncoder::new(&[T::from(1.23).unwrap()]); +const RD_BENCH_ARGS: &[(usize, f64)] = &[ + // length, fraction_patch + (10_000, 0.0), + (10_000, 0.01), + (10_000, 0.1), + (100_000, 0.0), + (100_000, 0.01), + (100_000, 0.1), +]; + +fn make_rd_array(n: usize, fraction_patch: f64) -> PrimitiveArray { + let base_val = T::from(1.23).unwrap(); + let mut rng = StdRng::seed_from_u64(42); + let mut values = buffer![base_val; n].into_mut(); + if fraction_patch > 0.0 { + let outlier = T::from(1000.0).unwrap(); + for index in 0..values.len() { + if rng.random_bool(fraction_patch) { + values[index] = outlier; + } + } + } + PrimitiveArray::new(values.freeze(), Validity::NonNullable) +} + +#[divan::bench(types = [f32, f64], args = RD_BENCH_ARGS)] +fn compress_rd(bencher: Bencher, args: (usize, f64)) { + let (n, fraction_patch) = args; + let primitive = make_rd_array::(n, fraction_patch); + let encoder = RDEncoder::new(primitive.as_slice::()); bencher .with_inputs(|| (&primitive, &encoder)) .bench_refs(|(primitive, encoder)| encoder.encode(primitive)) } -#[divan::bench(types = [f32, f64], args = [10_000, 100_000])] -fn decompress_rd(bencher: Bencher, n: usize) { - let primitive = PrimitiveArray::new(buffer![T::from(1.23).unwrap(); n], Validity::NonNullable); - let encoder = RDEncoder::new(&[T::from(1.23).unwrap()]); +#[divan::bench(types = [f32, f64], args = RD_BENCH_ARGS)] +fn decompress_rd(bencher: Bencher, args: (usize, f64)) { + let (n, fraction_patch) = args; + let primitive = make_rd_array::(n, fraction_patch); + let encoder = RDEncoder::new(primitive.as_slice::()); let encoded = encoder.encode(&primitive); bencher diff --git a/encodings/alp/public-api.lock b/encodings/alp/public-api.lock index 7f769a0fc7a..5aeb977a646 100644 --- a/encodings/alp/public-api.lock +++ b/encodings/alp/public-api.lock @@ -552,7 +552,7 @@ pub fn f64::to_u16(bits: Self::UINT) -> u16 pub fn vortex_alp::alp_encode(parray: &vortex_array::arrays::primitive::array::PrimitiveArray, exponents: core::option::Option) -> vortex_error::VortexResult -pub fn vortex_alp::alp_rd_decode(left_parts: vortex_buffer::buffer::Buffer, left_parts_dict: &[u16], right_bit_width: u8, right_parts: vortex_buffer::buffer_mut::BufferMut<::UINT>, left_parts_patches: core::option::Option<&vortex_array::patches::Patches>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +pub fn vortex_alp::alp_rd_decode(left_parts: vortex_buffer::buffer_mut::BufferMut, left_parts_dict: &[u16], right_bit_width: u8, right_parts: vortex_buffer::buffer_mut::BufferMut<::UINT>, left_parts_patches: core::option::Option<&vortex_array::patches::Patches>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> pub fn vortex_alp::decompress_into_array(array: vortex_alp::ALPArray, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult diff --git a/encodings/alp/src/alp_rd/array.rs b/encodings/alp/src/alp_rd/array.rs index df72cc189d3..7f4b49cf0b8 100644 --- a/encodings/alp/src/alp_rd/array.rs +++ b/encodings/alp/src/alp_rd/array.rs @@ -26,7 +26,6 @@ use vortex_array::patches::PatchesMetadata; use vortex_array::serde::ArrayChildren; use vortex_array::stats::ArrayStats; use vortex_array::stats::StatsSetRef; -use vortex_array::validity::Validity; use vortex_array::vtable; use vortex_array::vtable::ArrayId; use vortex_array::vtable::VTable; @@ -41,7 +40,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use vortex_mask::Mask; use vortex_session::VortexSession; use crate::alp_rd::kernel::PARENT_KERNELS; @@ -300,38 +298,31 @@ impl VTable for ALPRD { let left_parts = array.left_parts().clone().execute::(ctx)?; let right_parts = array.right_parts().clone().execute::(ctx)?; - // Decode the left_parts using our builtin dictionary. let left_parts_dict = array.left_parts_dictionary(); - let validity = array - .left_parts() - .validity()? - .to_array(array.len()) - .execute::(ctx)?; - let decoded_array = if array.is_f32() { PrimitiveArray::new( alp_rd_decode::( - left_parts.into_buffer::(), + left_parts.into_buffer_mut::(), left_parts_dict, array.right_bit_width, right_parts.into_buffer_mut::(), array.left_parts_patches(), ctx, )?, - Validity::from_mask(validity, array.dtype().nullability()), + array.left_parts().validity()?, ) } else { PrimitiveArray::new( alp_rd_decode::( - left_parts.into_buffer::(), + left_parts.into_buffer_mut::(), left_parts_dict, array.right_bit_width, right_parts.into_buffer_mut::(), array.left_parts_patches(), ctx, )?, - Validity::from_mask(validity, array.dtype().nullability()), + array.left_parts().validity()?, ) }; diff --git a/encodings/alp/src/alp_rd/mod.rs b/encodings/alp/src/alp_rd/mod.rs index 58188f7ab9f..d14935b89af 100644 --- a/encodings/alp/src/alp_rd/mod.rs +++ b/encodings/alp/src/alp_rd/mod.rs @@ -285,13 +285,13 @@ impl RDEncoder { } } -/// Decode a vector of ALP-RD encoded values back into their original floating point format. +/// Decode ALP-RD encoded values back into their original floating point format. /// /// # Panics /// -/// The function panics if the provided `left_parts` and `right_parts` differ in length. +/// Panics if `left_parts` and `right_parts` differ in length. pub fn alp_rd_decode( - left_parts: Buffer, + mut left_parts: BufferMut, left_parts_dict: &[u16], right_bit_width: u8, right_parts: BufferMut, @@ -302,27 +302,39 @@ pub fn alp_rd_decode( vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len"); } - // Decode the left-parts dictionary - let mut values = BufferMut::::from_iter( - left_parts - .iter() - .map(|code| left_parts_dict[*code as usize]), - ); + let shift = right_bit_width as usize; - // Apply any patches if let Some(patches) = left_parts_patches { + for code in left_parts.iter_mut() { + *code = left_parts_dict[*code as usize]; + } let indices = patches.indices().clone().execute::(ctx)?; let patch_values = patches.values().clone().execute::(ctx)?; - alp_rd_apply_patches(&mut values, &indices, &patch_values, patches.offset()); - } + alp_rd_apply_patches(&mut left_parts, &indices, &patch_values, patches.offset()); + + alp_rd_combine_inplace::( + right_parts, + |right, &left| { + *right = (::from_u16(left) << shift) | *right; + }, + left_parts.as_ref(), + ) + } else { + // Pre-shift dictionary entries so the hot loop is just a lookup + OR. + let mut shifted_dict = [T::UINT::default(); MAX_DICT_SIZE as usize]; + for (i, &entry) in left_parts_dict.iter().enumerate() { + shifted_dict[i] = ::from_u16(entry) << shift; + } - // Shift the left-parts and add in the right-parts. - Ok(alp_rd_decode_core( - left_parts_dict, - right_bit_width, - right_parts, - values, - )) + alp_rd_combine_inplace::( + right_parts, + |right, &code| { + // SAFETY: codes are bounded by dict size (< left_parts_dict.len() <= MAX_DICT_SIZE). + *right = unsafe { *shifted_dict.get_unchecked(code as usize) } | *right; + }, + left_parts.as_ref(), + ) + } } /// Apply patches to the decoded left-parts values. @@ -343,23 +355,18 @@ fn alp_rd_apply_patches( }) } -/// Core decode logic shared between `alp_rd_decode` and `execute_alp_rd_decode`. -fn alp_rd_decode_core( - _left_parts_dict: &[u16], - right_bit_width: u8, - right_parts: BufferMut, - values: BufferMut, -) -> Buffer { - // Shift the left-parts and add in the right-parts. - let mut index = 0; - right_parts - .map_each_in_place(|right| { - let left = values[index]; - index += 1; - let left = ::from_u16(left); - T::from_bits((left << (right_bit_width as usize)) | right) - }) - .freeze() +/// Zip `right_parts` with `left_data`, apply `combine_fn` per element, then reinterpret the +/// buffer from `T::UINT` to `T` (same bit-width: u32↔f32, u64↔f64). +fn alp_rd_combine_inplace( + mut right_parts: BufferMut, + combine_fn: impl Fn(&mut T::UINT, &u16), + left_data: &[u16], +) -> VortexResult> { + for (right, left) in right_parts.as_mut_slice().iter_mut().zip(left_data.iter()) { + combine_fn(right, left); + } + // SAFETY: all bit patterns of T::UINT are valid T (u32↔f32 or u64↔f64). + Ok(unsafe { right_parts.transmute::() }.freeze()) } /// Find the best "cut point" for a set of floating point values such that we can