From 17e0e1de08383afe9edf050ff289037fbef6f5ea Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 16:38:28 -0400 Subject: [PATCH 1/4] boilerplate with `NormVector` encoding Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/mod.rs | 5 + vortex-tensor/src/encodings/norm/array.rs | 33 +++++ vortex-tensor/src/encodings/norm/mod.rs | 13 ++ .../src/encodings/norm/vtable/mod.rs | 121 ++++++++++++++++++ .../src/encodings/norm/vtable/operations.rs | 15 +++ .../src/encodings/norm/vtable/validity.rs | 14 ++ vortex-tensor/src/lib.rs | 6 +- 7 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 vortex-tensor/src/encodings/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/array.rs create mode 100644 vortex-tensor/src/encodings/norm/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/operations.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/validity.rs diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs new file mode 100644 index 00000000000..0d9feeb17ce --- /dev/null +++ b/vortex-tensor/src/encodings/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod norm; +// mod spherical; diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs new file mode 100644 index 00000000000..9e6e19dd3a8 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; + +/// A normalized array that stores unit-normalized vectors alongside their original L2 norms. +/// +/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The +/// original norms are stored separately so that the original vectors can be reconstructed. +#[derive(Debug, Clone)] +pub struct NormVectorArray { + /// The backing vector array that has been unit normalized. + /// + /// The underlying elements of the vector array must be floating-point. + vector_array: ArrayRef, + + /// The L2 (Frobenius) norms of each vector. + /// + /// This must have the same dtype as the elements of the vector array. + norms: ArrayRef, +} + +impl NormVectorArray { + /// Returns a reference to the backing vector array that has been unit normalized. + pub fn vector_array(&self) -> &ArrayRef { + &self.vector_array + } + + /// Returns a reference to the L2 (Frobenius) norms of each vector. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } +} diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs new file mode 100644 index 00000000000..252f867de9c --- /dev/null +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::NormVectorArray; + +// pub(crate) mod compute; + +mod vtable; +pub use vtable::NormVector; + +// #[cfg(test)] +// mod tests; diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs new file mode 100644 index 00000000000..55f7cd026dd --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::Hasher; + +use vortex::array::ArrayRef; +use vortex::array::EmptyMetadata; +use vortex::array::ExecutionCtx; +use vortex::array::ExecutionStep; +use vortex::array::Precision; +use vortex::array::buffer::BufferHandle; +use vortex::array::serde::ArrayChildren; +use vortex::array::stats::StatsSetRef; +use vortex::array::vtable; +use vortex::array::vtable::ArrayId; +use vortex::array::vtable::VTable; +use vortex::array::vtable::ValidityVTableFromChild; +use vortex::dtype::DType; +use vortex::error::VortexResult; +use vortex::session::VortexSession; + +use crate::encodings::norm::array::NormVectorArray; + +mod operations; +mod validity; + +vtable!(NormVector); + +#[derive(Debug)] +pub struct NormVector; + +impl VTable for NormVector { + type Array = NormVectorArray; + type Metadata = EmptyMetadata; + type OperationsVTable = Self; + type ValidityVTable = ValidityVTableFromChild; + + fn id(_array: &NormVectorArray) -> ArrayId { + ArrayId::new_ref("vortex.tensor.norm_vector") + } + + fn len(array: &NormVectorArray) -> usize { + array.vector_array().len() + } + + fn dtype(array: &NormVectorArray) -> &DType { + array.vector_array().dtype() + } + + fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { + array.vector_array().statistics() + } + + fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { + todo!() + } + + fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { + todo!() + } + + fn nbuffers(array: &NormVectorArray) -> usize { + todo!() + } + + fn buffer(array: &NormVectorArray, idx: usize) -> BufferHandle { + todo!() + } + + fn buffer_name(array: &NormVectorArray, idx: usize) -> Option { + todo!() + } + + fn nchildren(array: &NormVectorArray) -> usize { + todo!() + } + + fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { + todo!() + } + + fn child_name(array: &NormVectorArray, idx: usize) -> String { + todo!() + } + + fn metadata(array: &NormVectorArray) -> VortexResult { + todo!() + } + + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + todo!() + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + todo!() + } + + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + todo!() + } + + fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { + todo!() + } + + fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { + todo!() + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs new file mode 100644 index 00000000000..5c743f1d01d --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::vtable::OperationsVTable; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl OperationsVTable for NormVector { + fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { + todo!() + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/validity.rs b/vortex-tensor/src/encodings/norm/vtable/validity.rs new file mode 100644 index 00000000000..8925ffc7378 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/validity.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::vtable::ValidityChild; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl ValidityChild for NormVector { + fn validity_child(array: &NormVectorArray) -> &ArrayRef { + array.vector_array() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 56e96488167..7aca9f54ab3 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,10 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +pub mod matcher; +pub mod scalar_fns; + pub mod fixed_shape; pub mod vector; -pub mod matcher; -pub mod scalar_fns; +pub mod encodings; From 39e0a0ee9849afbf68fa0ba05f37d94cde9e5e33 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 17:13:02 -0400 Subject: [PATCH 2/4] add most implementation Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/array.rs | 58 +++++++++++- .../src/encodings/norm/vtable/mod.rs | 94 ++++++++++++++----- .../src/encodings/norm/vtable/operations.rs | 2 +- vortex-tensor/src/lib.rs | 2 + .../src/scalar_fns/cosine_similarity.rs | 18 ++-- vortex-tensor/src/scalar_fns/l2_norm.rs | 14 +-- vortex-tensor/src/scalar_fns/mod.rs | 2 - vortex-tensor/src/{scalar_fns => }/utils.rs | 0 8 files changed, 145 insertions(+), 45 deletions(-) rename vortex-tensor/src/{scalar_fns => }/utils.rs (100%) diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs index 9e6e19dd3a8..275d8b0f885 100644 --- a/vortex-tensor/src/encodings/norm/array.rs +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -2,6 +2,16 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; + +use crate::utils::extension_element_ptype; +use crate::vector::Vector; /// A normalized array that stores unit-normalized vectors alongside their original L2 norms. /// @@ -12,15 +22,54 @@ pub struct NormVectorArray { /// The backing vector array that has been unit normalized. /// /// The underlying elements of the vector array must be floating-point. - vector_array: ArrayRef, + pub(crate) vector_array: ArrayRef, /// The L2 (Frobenius) norms of each vector. /// /// This must have the same dtype as the elements of the vector array. - norms: ArrayRef, + pub(crate) norms: ArrayRef, } impl NormVectorArray { + /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms. + /// + /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and + /// `norms` must be a primitive array of the same float type with the same length. + pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let element_ptype = extension_element_ptype(ext)?; + + let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype must match vector element type" + ); + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + Ok(Self { + vector_array, + norms, + }) + } + /// Returns a reference to the backing vector array that has been unit normalized. pub fn vector_array(&self) -> &ArrayRef { &self.vector_array @@ -30,4 +79,9 @@ impl NormVectorArray { pub fn norms(&self) -> &ArrayRef { &self.norms } + + // TODO docs + pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult { + todo!() + } } diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 55f7cd026dd..783b225ad75 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -3,6 +3,8 @@ use std::hash::Hasher; +use vortex::array::ArrayEq; +use vortex::array::ArrayHash; use vortex::array::ArrayRef; use vortex::array::EmptyMetadata; use vortex::array::ExecutionCtx; @@ -16,10 +18,15 @@ use vortex::array::vtable::ArrayId; use vortex::array::vtable::VTable; use vortex::array::vtable::ValidityVTableFromChild; use vortex::dtype::DType; +use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::error::vortex_panic; use vortex::session::VortexSession; use crate::encodings::norm::array::NormVectorArray; +use crate::utils::extension_element_ptype; mod operations; mod validity; @@ -52,70 +59,109 @@ impl VTable for NormVector { } fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { - todo!() + array.vector_array().array_hash(state, precision); + array.norms().array_hash(state, precision); } fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { - todo!() + array.norms().array_eq(other.norms(), precision) + && array + .vector_array() + .array_eq(other.vector_array(), precision) } - fn nbuffers(array: &NormVectorArray) -> usize { - todo!() + fn nbuffers(_array: &NormVectorArray) -> usize { + 0 } - fn buffer(array: &NormVectorArray, idx: usize) -> BufferHandle { - todo!() + fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle { + vortex_panic!("NormVectorArray has no buffers (index {idx})") } - fn buffer_name(array: &NormVectorArray, idx: usize) -> Option { - todo!() + fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option { + vortex_panic!("NormVectorArray has no buffers (index {idx})") } - fn nchildren(array: &NormVectorArray) -> usize { - todo!() + fn nchildren(_array: &NormVectorArray) -> usize { + 2 } fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { - todo!() + match idx { + 0 => array.vector_array().clone(), + 1 => array.norms().clone(), + _ => vortex_panic!("NormVectorArray child index {idx} out of bounds"), + } } - fn child_name(array: &NormVectorArray, idx: usize) -> String { - todo!() + fn child_name(_array: &NormVectorArray, idx: usize) -> String { + match idx { + 0 => "vector_array".to_string(), + 1 => "norms".to_string(), + _ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"), + } } - fn metadata(array: &NormVectorArray) -> VortexResult { - todo!() + fn metadata(_array: &NormVectorArray) -> VortexResult { + Ok(EmptyMetadata) } - fn serialize(metadata: Self::Metadata) -> VortexResult>> { - todo!() + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(vec![])) } fn deserialize( - bytes: &[u8], + _bytes: &[u8], _dtype: &DType, _len: usize, _buffers: &[BufferHandle], _session: &VortexSession, ) -> VortexResult { - todo!() + Ok(EmptyMetadata) } fn build( dtype: &DType, len: usize, - metadata: &Self::Metadata, - buffers: &[BufferHandle], + _metadata: &Self::Metadata, + _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { - todo!() + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let vector_array = children.get(0, dtype, len)?; + + let ext = dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") + })?; + let element_ptype = extension_element_ptype(ext)?; + let norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let norms = children.get(1, &norms_dtype, len)?; + + NormVectorArray::try_new(vector_array, norms) } fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { - todo!() + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let [vector_array, norms]: [ArrayRef; 2] = children + .try_into() + .map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?; + + array.vector_array = vector_array; + array.norms = norms; + Ok(()) } fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { - todo!() + Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?)) } } diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs index 5c743f1d01d..0ecec996d9e 100644 --- a/vortex-tensor/src/encodings/norm/vtable/operations.rs +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -10,6 +10,6 @@ use crate::encodings::norm::vtable::NormVector; impl OperationsVTable for NormVector { fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { - todo!() + array.vector_array().scalar_at(index) } } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 7aca9f54ab3..c036b9854b2 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -12,3 +12,5 @@ pub mod fixed_shape; pub mod vector; pub mod encodings; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cd2f158d719..2b922649307 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. /// @@ -196,11 +196,11 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::constant_tensor_array; - use crate::scalar_fns::utils::test_helpers::constant_vector_array; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index e0a3bac4143..1535c7b7463 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -163,9 +163,9 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::l2_norm::L2Norm; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2597f1115c8..2f56305cd53 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -5,5 +5,3 @@ pub mod cosine_similarity; pub mod l2_norm; - -mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/utils.rs similarity index 100% rename from vortex-tensor/src/scalar_fns/utils.rs rename to vortex-tensor/src/utils.rs From cf65eaeb57777416e71019b4d5bd707140744237 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 17:48:17 -0400 Subject: [PATCH 3/4] implement compress and decompress Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/array.rs | 130 +++++++++++++++++- vortex-tensor/src/encodings/norm/mod.rs | 4 +- vortex-tensor/src/encodings/norm/tests.rs | 110 +++++++++++++++ .../src/encodings/norm/vtable/mod.rs | 2 +- 4 files changed, 240 insertions(+), 6 deletions(-) create mode 100644 vortex-tensor/src/encodings/norm/tests.rs diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs index 275d8b0f885..55b9efd7ac2 100644 --- a/vortex-tensor/src/encodings/norm/array.rs +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -1,16 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::Float; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::ToCanonical; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::ScalarFnArray; +use vortex::array::match_each_float_ptype; +use vortex::array::validity::Validity; use vortex::dtype::DType; use vortex::dtype::Nullability; +use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; use vortex::error::vortex_err; +use vortex::extension::EmptyMetadata; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ScalarFn; +use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; use crate::vector::Vector; /// A normalized array that stores unit-normalized vectors alongside their original L2 norms. @@ -70,6 +87,63 @@ impl NormVectorArray { }) } + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. + pub fn compress(vector_array: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let list_size = extension_list_size(ext)?; + let row_count = vector_array.len(); + + // Compute L2 norms using the scalar function. + let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); + let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)? + .to_primitive() + .into_array(); + + // Divide each vector element by its corresponding norm. + let storage = extension_storage(&vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + let norms_prim = norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let normalized_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let inv_norm = safe_inv_norm(norms_slice[i]); + flat.row::(i).iter().map(move |&v| v * inv_norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + normalized_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + + Self::try_new(normalized_vector, norms) + }) + } + /// Returns a reference to the backing vector array that has been unit normalized. pub fn vector_array(&self) -> &ArrayRef { &self.vector_array @@ -80,8 +154,58 @@ impl NormVectorArray { &self.norms } - // TODO docs - pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult { - todo!() + /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. + pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult { + let ext_dtype = self + .vector_array + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + self.vector_array.dtype() + ) + })?; + + let list_size = extension_list_size(ext_dtype)?; + let row_count = self.vector_array.len(); + + let storage = extension_storage(&self.vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + + let norms_prim = self.norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let result_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let norm = norms_slice[i]; + flat.row::(i).iter().map(move |&v| v * norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + result_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + } +} + +/// Returns `1 / norm` if the norm is non-zero, or zero otherwise. +/// +/// This avoids division by zero for zero-length or all-zero vectors. +fn safe_inv_norm(norm: T) -> T { + if norm == T::zero() { + T::zero() + } else { + T::one() / norm } } diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs index 252f867de9c..9cd20e5cdac 100644 --- a/vortex-tensor/src/encodings/norm/mod.rs +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -9,5 +9,5 @@ pub use array::NormVectorArray; mod vtable; pub use vtable::NormVector; -// #[cfg(test)] -// mod tests; +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs new file mode 100644 index 00000000000..652e916577b --- /dev/null +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::Extension; +use vortex::error::VortexResult; + +use crate::encodings::norm::NormVectorArray; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::vector_array; + +#[test] +fn encode_unit_vectors() -> VortexResult<()> { + // Already unit-length vectors: norms should be 1.0 and vectors unchanged. + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 0.0, 1.0, 0.0, // norm = 1.0 + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[1.0, 1.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[1.0, 0.0, 0.0]); + assert_close(flat.row::(1), &[0.0, 1.0, 0.0]); + + Ok(()) +} + +#[test] +fn encode_non_unit_vectors() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 0.0, 0.0, // norm = 0.0 (zero vector) + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[5.0, 0.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(flat.row::(1), &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip() -> VortexResult<()> { + let original_elements = &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ]; + let arr = vector_array(2, original_elements)?; + + let norm = NormVectorArray::compress(arr)?; + + // Execute to reconstruct the original vectors. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + // The reconstructed array should be a Vector extension array. + assert!(reconstructed.as_opt::().is_some()); + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0, 4.0]); + assert_close(flat.row::(1), &[6.0, 8.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip_zero_vector() -> VortexResult<()> { + let arr = vector_array(2, &[0.0, 0.0])?; + + let norm = NormVectorArray::compress(arr)?; + + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + // Zero vector should remain zero after round-trip. + assert_close(flat.row::(0), &[0.0, 0.0]); + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 783b225ad75..83facf80739 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -162,6 +162,6 @@ impl VTable for NormVector { } fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?)) + Ok(ExecutionStep::Done(array.decompress(ctx)?)) } } From 13dd7d4bc70338eff194a114fe10f8a3c243d37b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 18:03:27 -0400 Subject: [PATCH 4/4] fix scalars Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/tests.rs | 25 +++++++++ .../src/encodings/norm/vtable/operations.rs | 53 ++++++++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs index 652e916577b..ef87e18d912 100644 --- a/vortex-tensor/src/encodings/norm/tests.rs +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex::array::IntoArray; use vortex::array::VortexSessionExecute; use vortex::array::arrays::Extension; use vortex::error::VortexResult; @@ -108,3 +109,27 @@ fn execute_round_trip_zero_vector() -> VortexResult<()> { Ok(()) } + +#[test] +fn scalar_at_returns_original_vector() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let encoded = NormVectorArray::compress(arr)?; + + // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let decompressed = encoded.decompress(&mut ctx)?; + + let norm_array = encoded.into_array(); + for i in 0..2 { + assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?); + } + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs index 0ecec996d9e..a384501f8d8 100644 --- a/vortex-tensor/src/encodings/norm/vtable/operations.rs +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -1,15 +1,66 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::FixedSizeList; +use vortex::array::builtins::ArrayBuiltins; use vortex::array::vtable::OperationsVTable; +use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_err; use vortex::scalar::Scalar; +use vortex::scalar_fn::fns::operators::Operator; use crate::encodings::norm::array::NormVectorArray; use crate::encodings::norm::vtable::NormVector; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; impl OperationsVTable for NormVector { fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { - array.vector_array().scalar_at(index) + let ext = array + .vector_array() + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + array.vector_array().dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + // Get the storage (FixedSizeList) and slice out the elements for this row. + let storage = extension_storage(array.vector_array())?; + let fsl = storage + .as_opt::() + .ok_or_else(|| vortex_err!("expected FixedSizeList storage"))?; + let row_elements = fsl.fixed_size_list_elements_at(index)?; + + // Multiply all elements by the norm using a ConstantArray broadcast. + let norm_scalar = array.norms().scalar_at(index)?; + let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array(); + let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?; + + // Rebuild the FSL scalar, then wrap in the extension type. + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + let children: Vec = (0..list_size) + .map(|i| scaled.scalar_at(i)) + .collect::>()?; + + let fsl_scalar = + Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable); + + Ok(Scalar::extension_ref(ext.clone(), fsl_scalar)) } }