diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs new file mode 100644 index 00000000000..0d9feeb17ce --- /dev/null +++ b/vortex-tensor/src/encodings/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod norm; +// mod spherical; diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs new file mode 100644 index 00000000000..55b9efd7ac2 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::Float; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::ToCanonical; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::ScalarFnArray; +use vortex::array::match_each_float_ptype; +use vortex::array::validity::Validity; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::dtype::extension::ExtDType; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::extension::EmptyMetadata; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ScalarFn; + +use crate::scalar_fns::l2_norm::L2Norm; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::vector::Vector; + +/// A normalized array that stores unit-normalized vectors alongside their original L2 norms. +/// +/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The +/// original norms are stored separately so that the original vectors can be reconstructed. +#[derive(Debug, Clone)] +pub struct NormVectorArray { + /// The backing vector array that has been unit normalized. + /// + /// The underlying elements of the vector array must be floating-point. + pub(crate) vector_array: ArrayRef, + + /// The L2 (Frobenius) norms of each vector. + /// + /// This must have the same dtype as the elements of the vector array. + pub(crate) norms: ArrayRef, +} + +impl NormVectorArray { + /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms. + /// + /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and + /// `norms` must be a primitive array of the same float type with the same length. + pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let element_ptype = extension_element_ptype(ext)?; + + let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype must match vector element type" + ); + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + Ok(Self { + vector_array, + norms, + }) + } + + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. + pub fn compress(vector_array: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let list_size = extension_list_size(ext)?; + let row_count = vector_array.len(); + + // Compute L2 norms using the scalar function. + let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); + let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)? + .to_primitive() + .into_array(); + + // Divide each vector element by its corresponding norm. + let storage = extension_storage(&vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + let norms_prim = norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let normalized_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let inv_norm = safe_inv_norm(norms_slice[i]); + flat.row::(i).iter().map(move |&v| v * inv_norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + normalized_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + + Self::try_new(normalized_vector, norms) + }) + } + + /// Returns a reference to the backing vector array that has been unit normalized. + pub fn vector_array(&self) -> &ArrayRef { + &self.vector_array + } + + /// Returns a reference to the L2 (Frobenius) norms of each vector. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } + + /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. + pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult { + let ext_dtype = self + .vector_array + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + self.vector_array.dtype() + ) + })?; + + let list_size = extension_list_size(ext_dtype)?; + let row_count = self.vector_array.len(); + + let storage = extension_storage(&self.vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + + let norms_prim = self.norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let result_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let norm = norms_slice[i]; + flat.row::(i).iter().map(move |&v| v * norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + result_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + } +} + +/// Returns `1 / norm` if the norm is non-zero, or zero otherwise. +/// +/// This avoids division by zero for zero-length or all-zero vectors. +fn safe_inv_norm(norm: T) -> T { + if norm == T::zero() { + T::zero() + } else { + T::one() / norm + } +} diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs new file mode 100644 index 00000000000..9cd20e5cdac --- /dev/null +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::NormVectorArray; + +// pub(crate) mod compute; + +mod vtable; +pub use vtable::NormVector; + +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs new file mode 100644 index 00000000000..ef87e18d912 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::IntoArray; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::Extension; +use vortex::error::VortexResult; + +use crate::encodings::norm::NormVectorArray; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::vector_array; + +#[test] +fn encode_unit_vectors() -> VortexResult<()> { + // Already unit-length vectors: norms should be 1.0 and vectors unchanged. + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 0.0, 1.0, 0.0, // norm = 1.0 + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[1.0, 1.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[1.0, 0.0, 0.0]); + assert_close(flat.row::(1), &[0.0, 1.0, 0.0]); + + Ok(()) +} + +#[test] +fn encode_non_unit_vectors() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 0.0, 0.0, // norm = 0.0 (zero vector) + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[5.0, 0.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(flat.row::(1), &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip() -> VortexResult<()> { + let original_elements = &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ]; + let arr = vector_array(2, original_elements)?; + + let norm = NormVectorArray::compress(arr)?; + + // Execute to reconstruct the original vectors. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + // The reconstructed array should be a Vector extension array. + assert!(reconstructed.as_opt::().is_some()); + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0, 4.0]); + assert_close(flat.row::(1), &[6.0, 8.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip_zero_vector() -> VortexResult<()> { + let arr = vector_array(2, &[0.0, 0.0])?; + + let norm = NormVectorArray::compress(arr)?; + + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + // Zero vector should remain zero after round-trip. + assert_close(flat.row::(0), &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn scalar_at_returns_original_vector() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let encoded = NormVectorArray::compress(arr)?; + + // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let decompressed = encoded.decompress(&mut ctx)?; + + let norm_array = encoded.into_array(); + for i in 0..2 { + assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?); + } + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs new file mode 100644 index 00000000000..83facf80739 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::Hasher; + +use vortex::array::ArrayEq; +use vortex::array::ArrayHash; +use vortex::array::ArrayRef; +use vortex::array::EmptyMetadata; +use vortex::array::ExecutionCtx; +use vortex::array::ExecutionStep; +use vortex::array::Precision; +use vortex::array::buffer::BufferHandle; +use vortex::array::serde::ArrayChildren; +use vortex::array::stats::StatsSetRef; +use vortex::array::vtable; +use vortex::array::vtable::ArrayId; +use vortex::array::vtable::VTable; +use vortex::array::vtable::ValidityVTableFromChild; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::error::vortex_panic; +use vortex::session::VortexSession; + +use crate::encodings::norm::array::NormVectorArray; +use crate::utils::extension_element_ptype; + +mod operations; +mod validity; + +vtable!(NormVector); + +#[derive(Debug)] +pub struct NormVector; + +impl VTable for NormVector { + type Array = NormVectorArray; + type Metadata = EmptyMetadata; + type OperationsVTable = Self; + type ValidityVTable = ValidityVTableFromChild; + + fn id(_array: &NormVectorArray) -> ArrayId { + ArrayId::new_ref("vortex.tensor.norm_vector") + } + + fn len(array: &NormVectorArray) -> usize { + array.vector_array().len() + } + + fn dtype(array: &NormVectorArray) -> &DType { + array.vector_array().dtype() + } + + fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { + array.vector_array().statistics() + } + + fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { + array.vector_array().array_hash(state, precision); + array.norms().array_hash(state, precision); + } + + fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { + array.norms().array_eq(other.norms(), precision) + && array + .vector_array() + .array_eq(other.vector_array(), precision) + } + + fn nbuffers(_array: &NormVectorArray) -> usize { + 0 + } + + fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn nchildren(_array: &NormVectorArray) -> usize { + 2 + } + + fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.vector_array().clone(), + 1 => array.norms().clone(), + _ => vortex_panic!("NormVectorArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &NormVectorArray, idx: usize) -> String { + match idx { + 0 => "vector_array".to_string(), + 1 => "norms".to_string(), + _ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"), + } + } + + fn metadata(_array: &NormVectorArray) -> VortexResult { + Ok(EmptyMetadata) + } + + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + _bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(EmptyMetadata) + } + + fn build( + dtype: &DType, + len: usize, + _metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let vector_array = children.get(0, dtype, len)?; + + let ext = dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") + })?; + let element_ptype = extension_element_ptype(ext)?; + let norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let norms = children.get(1, &norms_dtype, len)?; + + NormVectorArray::try_new(vector_array, norms) + } + + fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let [vector_array, norms]: [ArrayRef; 2] = children + .try_into() + .map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?; + + array.vector_array = vector_array; + array.norms = norms; + Ok(()) + } + + fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(ExecutionStep::Done(array.decompress(ctx)?)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs new file mode 100644 index 00000000000..a384501f8d8 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::FixedSizeList; +use vortex::array::builtins::ArrayBuiltins; +use vortex::array::vtable::OperationsVTable; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_err; +use vortex::scalar::Scalar; +use vortex::scalar_fn::fns::operators::Operator; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; + +impl OperationsVTable for NormVector { + fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { + let ext = array + .vector_array() + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + array.vector_array().dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + // Get the storage (FixedSizeList) and slice out the elements for this row. + let storage = extension_storage(array.vector_array())?; + let fsl = storage + .as_opt::() + .ok_or_else(|| vortex_err!("expected FixedSizeList storage"))?; + let row_elements = fsl.fixed_size_list_elements_at(index)?; + + // Multiply all elements by the norm using a ConstantArray broadcast. + let norm_scalar = array.norms().scalar_at(index)?; + let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array(); + let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?; + + // Rebuild the FSL scalar, then wrap in the extension type. + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + let children: Vec = (0..list_size) + .map(|i| scaled.scalar_at(i)) + .collect::>()?; + + let fsl_scalar = + Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable); + + Ok(Scalar::extension_ref(ext.clone(), fsl_scalar)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/validity.rs b/vortex-tensor/src/encodings/norm/vtable/validity.rs new file mode 100644 index 00000000000..8925ffc7378 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/validity.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::vtable::ValidityChild; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl ValidityChild for NormVector { + fn validity_child(array: &NormVectorArray) -> &ArrayRef { + array.vector_array() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 56e96488167..c036b9854b2 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,12 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +pub mod matcher; +pub mod scalar_fns; + pub mod fixed_shape; pub mod vector; -pub mod matcher; -pub mod scalar_fns; +pub mod encodings; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cd2f158d719..2b922649307 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. /// @@ -196,11 +196,11 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::constant_tensor_array; - use crate::scalar_fns::utils::test_helpers::constant_vector_array; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index e0a3bac4143..1535c7b7463 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -163,9 +163,9 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::l2_norm::L2Norm; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2597f1115c8..2f56305cd53 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -5,5 +5,3 @@ pub mod cosine_similarity; pub mod l2_norm; - -mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/utils.rs similarity index 100% rename from vortex-tensor/src/scalar_fns/utils.rs rename to vortex-tensor/src/utils.rs