sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py
@@ -18,6 +18,7 @@
import logging
import os
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
@@ -41,6 +42,7 @@
determine_dtype,
smallest_uint_dtype,
)
from bionemo.scdl.util.partition_scdl import partition_scdl
from bionemo.scdl.util.scdl_constants import FLOAT_ORDER, INT_ORDER, FileNames, Mode, NeighborSamplingStrategy


@@ -128,6 +130,8 @@ def __init__(
self.data_path: str = data_path
self.header: SCDLHeader = None
self.mode: Mode = mode
self._is_chunked: bool = False
self._chunks: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []
self.paginated_load_cutoff = paginated_load_cutoff
self.load_block_row_size = load_block_row_size
self.var_feature_index_name = var_feature_index_name
@@ -436,10 +440,16 @@ def get_row(
List[np.ndarray]: optional, corresponding variable (column) features.
List[np.ndarray]: optional, corresponding observed (row) features.
"""
start = self.row_index[index]
end = self.row_index[index + 1]
values = self.data[start:end]
columns = self.col_index[start:end]
if self._is_chunked:
chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index)
data, rowptr, colptr = self._chunks[chunk_id]
start, end = rowptr[local_idx], rowptr[local_idx + 1]
values, columns = data[start:end], colptr[start:end]
else:
start = self.row_index[index]
end = self.row_index[index + 1]
values = self.data[start:end]
columns = self.col_index[start:end]
ret = (values, columns)
var_features = (
self._var_feature_index.lookup(index, select_features=var_feature_names)[0]
@@ -685,37 +695,52 @@ def load(self, stored_path: str) -> None:
raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes")
self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string

# Metadata is required, so we must check if it exists and fail if not.
if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"):
raise FileNotFoundError(
f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist."
)

with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi:
self.metadata = json.load(mfi)
# Load metadata if it exists
metadata_path = f"{self.data_path}/{FileNames.METADATA.value}"
if os.path.exists(metadata_path):
with open(metadata_path, Mode.READ_APPEND.value) as mfi:
self.metadata = json.load(mfi)

# Load feature indices
if os.path.exists(f"{self.data_path}/{FileNames.VAR_FEATURES.value}"):
self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.VAR_FEATURES.value}")
elif os.path.exists(
f"{self.data_path}/{FileNames.FEATURES.value}"
): # Backward compatibility with old features file
elif os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"):
self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}")
if os.path.exists(f"{self.data_path}/{FileNames.OBS_FEATURES.value}"):
self._obs_feature_index = ObservedFeatureIndex.load(f"{self.data_path}/{FileNames.OBS_FEATURES.value}")
# mmap the existing arrays
self.data = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"]
)
self.row_index = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"]
)
self.col_index = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"]
)

# Load neighbor data
if self.load_neighbors:
self._load_neighbor_memmaps()
# Load data arrays - chunked vs monolithic
if self.header is not None and self.header.backend == Backend.CHUNKED_MEMMAP_V0:
self._is_chunked = True
self._load_chunk_memmaps()
else:
self.data = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"]
)
self.row_index = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"]
)
self.col_index = self._load_mmap_file_if_exists(
f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"]
)
if self.load_neighbors:
self._load_neighbor_memmaps()

def _load_chunk_memmaps(self) -> None:
"""Preload all chunk memmaps (lazy - just file handles, no RAM)."""
for chunk_id in range(self.header.chunked_info.num_chunks):
chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}"
self._chunks.append(
(
np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"),
np.memmap(
chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"
),
np.memmap(
chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"
),
)
)
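
The "no RAM" note above relies on np.memmap being lazy: creating the map only opens a file handle and records the dtype and offset, and pages are read from disk the first time they are indexed. A standalone illustration of that behavior (editorial sketch, not part of this diff and not tied to SCDL internals):

import os
import tempfile

import numpy as np

# Write roughly 80 MB of float64 values to a scratch file.
path = os.path.join(tempfile.mkdtemp(), "data.bin")
np.arange(10_000_000, dtype=np.float64).tofile(path)

# Opening the memmap is cheap: no array data is read yet.
arr = np.memmap(path, dtype=np.float64, mode="r")

# Only the pages backing this small slice are pulled into memory on access.
print(arr[5_000_000:5_000_010])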

def _write_metadata(self) -> None:
with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi:
@@ -1218,6 +1243,8 @@ def number_of_rows(self) -> int:
ValueError if the length of the number of rows in the feature
index does not correspond to the number of stored rows.
"""
if self._is_chunked:
return self.header.chunked_info.total_rows
if len(self._var_feature_index) > 0 and self._var_feature_index.number_of_rows() != self.row_index.size - 1:
raise ValueError(
f"""The number of rows in the feature index {self._var_feature_index.number_of_rows()}
@@ -1445,3 +1472,32 @@ def concat(
mode=Mode.READ_APPEND.value,
)
self.save()

def to_chunked(
self, output_path: Optional[str] = None, chunk_size: int = 100_000, delete_original: bool = False
) -> "SingleCellMemMapDataset":
"""Convert this dataset to a chunked format for efficient remote access.

Args:
output_path: Path where the chunked dataset will be created. If None, the dataset is converted in place.
chunk_size: Number of rows per chunk (default: 100,000).
delete_original: If True and output_path is set, delete the original after conversion.

Returns:
A new SingleCellMemMapDataset instance pointing to the chunked data.
"""
if self._is_chunked:
raise ValueError("Dataset is already chunked")

src = Path(self.data_path)
if output_path is None:
# In-place: partition to temp, then swap
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / "chunked"
partition_scdl(src, tmp_path, chunk_size=chunk_size)
shutil.rmtree(src)
shutil.move(str(tmp_path), str(src))
return SingleCellMemMapDataset(str(src))

partition_scdl(src, Path(output_path), chunk_size=chunk_size, delete_original=delete_original)
return SingleCellMemMapDataset(output_path)
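
A brief usage sketch of the conversion added here (paths are hypothetical; any existing monolithic SCDL dataset would work the same way):

ds = SingleCellMemMapDataset("my_dataset.scdl")

# Write a chunked copy next to the original...
chunked = ds.to_chunked(output_path="my_dataset_chunked.scdl", chunk_size=50_000)

# ...or convert in place, replacing the monolithic arrays under the same path.
# ds.to_chunked()

Row access through get_row() is unchanged for callers; the chunk lookup happens internally.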
59 changes: 58 additions & 1 deletion sub-packages/bionemo-scdl/src/bionemo/scdl/schema/header.py
@@ -407,6 +407,38 @@ def __repr__(self) -> str:
return self.__str__()


class ChunkedInfo:
"""Chunking metadata for CHUNKED_MEMMAP backend."""

def __init__(self, chunk_size: int, num_chunks: int, total_rows: int):
"""Initialize chunked info."""
self.chunk_size = chunk_size
self.num_chunks = num_chunks
self.total_rows = total_rows

def get_chunk_for_row(self, global_idx: int) -> Tuple[int, int]:
"""Return (chunk_id, local_idx) for a global row index."""
if global_idx < 0 or global_idx >= self.total_rows:
raise IndexError(f"Row index {global_idx} out of range [0, {self.total_rows})")
return global_idx // self.chunk_size, global_idx % self.chunk_size

def serialize(self, codec: BinaryHeaderCodec) -> bytes:
"""Serialize to binary format."""
return (
codec.pack_uint32(self.chunk_size)
+ codec.pack_uint32(self.num_chunks)
+ codec.pack_uint64(self.total_rows)
)

@classmethod
def deserialize(cls, codec: BinaryHeaderCodec, data: bytes, offset: int = 0) -> Tuple["ChunkedInfo", int]:
"""Deserialize from binary data. Returns (ChunkedInfo, bytes_consumed)."""
chunk_size = codec.unpack_uint32(data[offset : offset + 4])
num_chunks = codec.unpack_uint32(data[offset + 4 : offset + 8])
total_rows = codec.unpack_uint64(data[offset + 8 : offset + 16])
return cls(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows), 16
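
Because every chunk except possibly the last holds exactly chunk_size rows, the lookup is plain integer arithmetic, and the serialized form is a fixed 16 bytes (uint32 chunk_size, uint32 num_chunks, uint64 total_rows). A small worked example, with values chosen purely for illustration:

info = ChunkedInfo(chunk_size=100_000, num_chunks=3, total_rows=250_000)

assert info.get_chunk_for_row(0) == (0, 0)
assert info.get_chunk_for_row(150_123) == (1, 50_123)
assert info.get_chunk_for_row(249_999) == (2, 49_999)  # last row of the short final chunk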


class SCDLHeader:
"""Header for a SCDL archive following the official schema specification.

@@ -423,6 +455,7 @@ def __init__(
backend: Backend = Backend.MEMMAP_V0,
arrays: Optional[List[ArrayInfo]] = None,
feature_indices: Optional[List[FeatureIndexInfo]] = None,
chunked_info: Optional["ChunkedInfo"] = None,
):
"""Initialize SCDL header.

@@ -431,12 +464,14 @@
backend: Storage backend type
arrays: List of arrays in the archive
feature_indices: Optional list of feature indices in the archive
chunked_info: Optional chunking metadata for CHUNKED_MEMMAP backend
"""
self.version = version or CurrentSCDLVersion()
self.endianness = Endianness.NETWORK # Always network byte order per spec
self.backend = backend
self.arrays = arrays or []
self.feature_indices = feature_indices or []
self.chunked_info = chunked_info

# Create codec with network byte order
self._codec = BinaryHeaderCodec(self.endianness)
@@ -525,6 +560,13 @@ def serialize(self) -> bytes:
for feature_index in self.feature_indices:
data += feature_index.serialize(self._codec)

# Chunked info (optional, for CHUNKED_MEMMAP backend)
if self.chunked_info is not None:
data += self._codec.pack_uint8(1) # has_chunked_info = true
data += self.chunked_info.serialize(self._codec)
else:
data += self._codec.pack_uint8(0) # has_chunked_info = false

return data

except Exception as e:
@@ -617,7 +659,22 @@ def deserialize(cls, data: bytes) -> "SCDLHeader":
feature_indices.append(feature_index)
offset += bytes_consumed

header = cls(version=version, backend=backend, arrays=arrays, feature_indices=feature_indices)
# Read chunked info (optional, for backwards compatibility)
chunked_info = None
if offset < len(data):
has_chunked_info = codec.unpack_uint8(data[offset : offset + 1])
offset += 1
if has_chunked_info:
chunked_info, bytes_consumed = ChunkedInfo.deserialize(codec, data, offset)
offset += bytes_consumed

header = cls(
version=version,
backend=backend,
arrays=arrays,
feature_indices=feature_indices,
chunked_info=chunked_info,
)
return header

except HeaderSerializationError:
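A minimal round-trip sketch of the new optional field, assuming the constructors shown in this file; older headers simply end before the flag byte, so deserialize() leaves chunked_info as None:

from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader
from bionemo.scdl.util.scdl_constants import Backend

header = SCDLHeader(
    backend=Backend.CHUNKED_MEMMAP_V0,
    chunked_info=ChunkedInfo(chunk_size=100_000, num_chunks=3, total_rows=250_000),
)
blob = header.serialize()

restored = SCDLHeader.deserialize(blob)
assert restored.backend == Backend.CHUNKED_MEMMAP_V0
assert restored.chunked_info.total_rows == 250_000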
sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partition a monolithic SCDL dataset into chunks."""

import shutil
from pathlib import Path

import numpy as np

from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader
from bionemo.scdl.util.scdl_constants import Backend, FileNames


def partition_scdl(
input_path: Path,
output_path: Path,
chunk_size: int = 100_000,
delete_original: bool = False,
) -> SCDLHeader:
"""Partition an SCDL dataset into chunks."""
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

input_path, output_path = Path(input_path), Path(output_path)

if not input_path.exists():
raise FileNotFoundError(f"Input path does not exist: {input_path}")
if output_path.exists():
raise FileExistsError(f"Output path already exists: {output_path}")

output_path.mkdir(parents=True)

# Load source dataset
source_ds = SingleCellMemMapDataset(str(input_path))
total_rows = len(source_ds)
rowptr = source_ds.row_index
if chunk_size <= 0:
raise ValueError(f"Chunk size must be greater than 0, got {chunk_size}")
if total_rows <= 0:
raise ValueError(f"Total rows must be greater than 0, got {total_rows}")
num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size)

# Create chunks
for chunk_id in range(num_chunks):
row_start = chunk_id * chunk_size
row_end = min(row_start + chunk_size, total_rows)
chunk_dir = output_path / f"chunk_{chunk_id:05d}"
chunk_dir.mkdir()

data_start, data_end = int(rowptr[row_start]), int(rowptr[row_end])

# Write chunk files using memmap slicing
chunk_rowptr = rowptr[row_start : row_end + 1] - data_start
with open(chunk_dir / FileNames.ROWPTR.value, "wb") as f:
f.write(chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]).tobytes())
with open(chunk_dir / FileNames.DATA.value, "wb") as f:
f.write(np.array(source_ds.data[data_start:data_end]).tobytes())
with open(chunk_dir / FileNames.COLPTR.value, "wb") as f:
f.write(np.array(source_ds.col_index[data_start:data_end]).tobytes())

# Copy features and metadata
for name in [FileNames.VAR_FEATURES.value, FileNames.OBS_FEATURES.value]:
if (input_path / name).exists():
shutil.copytree(input_path / name, output_path / name)
for name in [FileNames.VERSION.value, FileNames.METADATA.value]:
if (input_path / name).exists():
shutil.copy(input_path / name, output_path / name)

# Update header with chunked info
header = source_ds.header if source_ds.header else SCDLHeader()
header.backend = Backend.CHUNKED_MEMMAP_V0
header.chunked_info = ChunkedInfo(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows)
header.save(str(output_path / FileNames.HEADER.value))

if delete_original:
del source_ds # Release memmap handles
shutil.rmtree(input_path)

return header
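
A short usage sketch of the partitioner on its own (paths are hypothetical; SingleCellMemMapDataset.to_chunked wraps the same call):

from pathlib import Path

from bionemo.scdl.util.partition_scdl import partition_scdl

header = partition_scdl(
    Path("my_dataset.scdl"),
    Path("my_dataset_chunked.scdl"),
    chunk_size=50_000,
)
print(header.chunked_info.num_chunks)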
sub-packages/bionemo-scdl/src/bionemo/scdl/util/scdl_constants.py
@@ -163,6 +163,7 @@ class Backend(IntEnum):
"""

MEMMAP_V0 = 1
CHUNKED_MEMMAP_V0 = 2 # Chunked memmap for large datasets with remote access support


class Mode(str, Enum):
19 changes: 15 additions & 4 deletions sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py
@@ -199,13 +199,24 @@ def _make(tmp_path):

@pytest.fixture
def make_h5ad_with_raw(make_random_csr):
"""Factory to create an h5ad with uniquely randomized data for the fields .raw.X and .X"""
"""Factory to create an h5ad with uniquely randomized data for .raw.X, .X, obs, and var."""

def _make(tmp_path):
X = make_random_csr(total_nnz=100, n_cols=50, seed=42)
X_raw = make_random_csr(total_nnz=100, n_cols=50, seed=43)
n_rows, n_cols = 100, 50
X = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=42)
X_raw = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=43)

obs = pd.DataFrame(
{"cell_type": [f"type_{i % 3}" for i in range(n_rows)]},
index=[f"cell_{i}" for i in range(n_rows)],
)
var = pd.DataFrame(
{"gene_name": [f"gene_{i}" for i in range(n_cols)]},
index=[f"ENSG{i:08d}" for i in range(n_cols)],
)

h = tmp_path / "var.h5ad"
ad.AnnData(X=X, var=pd.DataFrame(index=np.arange(X.shape[1])), raw={"X": X_raw}).write_h5ad(h)
ad.AnnData(X=X, obs=obs, var=var, raw={"X": X_raw}).write_h5ad(h)
return h

return _make