From f2ac2cb71c4a24c1066ff1b59605950be144f7ed Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 06:47:54 +0000 Subject: [PATCH 1/8] docs: add DBConnector design spec for PLT-1078 --- .../specs/2026-03-21-db-connector-design.md | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 superpowers/specs/2026-03-21-db-connector-design.md diff --git a/superpowers/specs/2026-03-21-db-connector-design.md b/superpowers/specs/2026-03-21-db-connector-design.md new file mode 100644 index 0000000..ce751da --- /dev/null +++ b/superpowers/specs/2026-03-21-db-connector-design.md @@ -0,0 +1,215 @@ +# DBConnector: Shared Database Access Abstraction + +**Issue:** PLT-1078 +**Date:** 2026-03-21 +**Status:** Approved + +--- + +## Problem + +With three `ArrowDatabaseProtocol` backends (SQLite, PostgreSQL, SpiralDB) and three corresponding `Source` implementations being built in parallel (PLT-1072–1077), each pair shares the same underlying DB technology. Without a shared abstraction, every contributor would need to implement: + +- Connection lifecycle management +- DB-native → Arrow type mapping +- Schema introspection (list tables, get primary keys, get column metadata) +- Query execution + +…twice: once for the `ArrowDatabaseProtocol` implementation, once for the `Source`. + +This compounds design friction and creates diverging implementations of the same low-level logic. 
+ +--- + +## Design + +### Layering + +``` +┌──────────────────────────────────────────────────────────┐ +│ User-facing layer │ +│ ArrowDatabaseProtocol (read+write) RootSource (read) │ +└──────────────────┬──────────────────────────┬────────────┘ + │ │ + ┌─────────────▼──────────────────────────▼──────────┐ + │ ConnectorArrowDatabase DBTableSource │ + │ (generic ArrowDB impl) (generic Source impl) │ + └─────────────────────────┬─────────────────────────┘ + │ shared dependency + ┌───────────────▼───────────────────┐ + │ DBConnectorProtocol │ + │ (connection + type mapping + │ + │ schema introspection + queries) │ + └──────┬────────────┬────────────────┘ + │ │ │ + SQLiteConnector PostgreSQLConnector SpiralDBConnector + (PLT-1076) (PLT-1075) (PLT-1074) +``` + +**`DBConnectorProtocol`** — the minimal shared raw-access interface every DB technology must implement. + +**`ConnectorArrowDatabase(connector)`** — generic `ArrowDatabaseProtocol` implementation on top of any `DBConnectorProtocol`. Owns all record-management logic: `record_path → table name` mapping, `__record_id` column convention, in-memory pending batch management, deduplication, upsert, schema evolution, flush semantics. + +**`DBTableSource(connector, table_name)`** — generic read-only `RootSource` on top of any `DBConnectorProtocol`. Uses PK columns as default tag columns, delegates all data fetching and type mapping to the connector, feeds Arrow data into `SourceStreamBuilder`. + +Contributors implementing PLT-1074/1075/1076 implement **one class** (`DBConnector`) and get both `ArrowDatabase` and `Source` support. + +--- + +## `DBConnectorProtocol` Interface + +```python +@dataclass(frozen=True) +class ColumnInfo: + """A single column with its Arrow-mapped type. 
Type mapping is the connector's responsibility.""" + name: str + arrow_type: pa.DataType + nullable: bool = True + + +@runtime_checkable +class DBConnectorProtocol(Protocol): + # ── Schema introspection (used by both Source and ArrowDatabase) ────────── + def get_table_names(self) -> list[str]: ... + def get_pk_columns(self, table_name: str) -> list[str]: ... + def get_column_info(self, table_name: str) -> list[ColumnInfo]: ... + + # ── Read (used by both Source and ArrowDatabase) ────────────────────────── + def iter_batches( + self, + query: str, + params: Any = None, + batch_size: int = 1000, + ) -> Iterator[pa.RecordBatch]: ... + + # ── Write (used only by ConnectorArrowDatabase) ─────────────────────────── + def create_table_if_not_exists( + self, + table_name: str, + columns: list[ColumnInfo], + pk_column: str, + ) -> None: ... + + def upsert_records( + self, + table_name: str, + records: pa.Table, + id_column: str, + skip_existing: bool = False, + ) -> None: ... + + # ── Lifecycle ───────────────────────────────────────────────────────────── + def close(self) -> None: ... + def __enter__(self) -> "DBConnectorProtocol": ... + def __exit__(self, *args: Any) -> None: ... + + # ── Serialization ───────────────────────────────────────────────────────── + def to_config(self) -> dict[str, Any]: ... + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DBConnectorProtocol": ... +``` + +**Type mapping is fully owned by the connector.** `iter_batches()` returns Arrow-typed `RecordBatch`es. `get_column_info()` returns `ColumnInfo` with `arrow_type` already mapped. `upsert_records()` accepts an Arrow table and the connector handles conversion to DB-native types internally. `ConnectorArrowDatabase` and `DBTableSource` are completely DB-type-agnostic. + +**Write methods** (`create_table_if_not_exists`, `upsert_records`) are on the protocol but are only called by `ConnectorArrowDatabase`. `DBTableSource` never calls them. 
This avoids the need for a separate `ReadableDBConnectorProtocol` split while keeping the interface intentionally minimal. + +--- + +## `ConnectorArrowDatabase` + +Implements `ArrowDatabaseProtocol` on top of any `DBConnectorProtocol`. Mirrors the pending-batch + flush semantics of `DeltaTableDatabase` and `InMemoryArrowDatabase`: + +- **`record_path → table_name`**: `"__".join(record_path)` with validation (max depth, safe characters). Uses `__` as separator to avoid collisions with path components. +- **`__record_id` column**: Standard column added to every table, used as the primary key in the underlying DB table. +- **Pending batch**: Records are buffered in memory as Arrow tables, keyed by `record_path`. `flush()` commits all pending batches via `connector.upsert_records()`. +- **Deduplication**: Within-batch deduplication keeps the last occurrence per `__record_id`. +- **`skip_duplicates`**: Passes through to `connector.upsert_records(skip_existing=...)`. +- **Schema evolution**: `connector.create_table_if_not_exists()` is called at flush time if the table does not yet exist. For schema changes on existing tables, evolution is delegated to the connector. + +### `record_path → table_name` mapping + +```python +def _path_to_table_name(record_path: tuple[str, ...]) -> str: + # Joins with '__' separator; sanitizes each component (replaces non-alphanumeric with '_') + return "__".join(re.sub(r"[^a-zA-Z0-9_]", "_", part) for part in record_path) +``` + +--- + +## `DBTableSource` + +Implements `RootSource` on top of any `DBConnectorProtocol`. Read-only. + +```python +DBTableSource( + connector: DBConnectorProtocol, + table_name: str, + tag_columns: Collection[str] | None = None, # None → use PK columns + record_id_column: str | None = None, + source_id: str | None = None, # None → defaults to table_name + **kwargs, # passed to RootSource (label, data_context, config) +) +``` + +Construction flow: +1. 
Resolve `tag_columns`: if `None`, call `connector.get_pk_columns(table_name)`. +2. Fetch full table: `list(connector.iter_batches(f'SELECT * FROM "{table_name}"'))` → `pa.Table.from_batches(...)`. +3. Feed into `SourceStreamBuilder.build(table, tag_columns=..., source_id=..., record_id_column=...)`. +4. Store result stream, tag_columns, source_id. + +`to_config()` serializes `connector.to_config()`, `table_name`, `tag_columns`, `record_id_column`, `source_id` plus identity fields from `_identity_config()`. + +`from_config()` calls a `build_db_connector_from_config(config["connector"])` registry helper (to be implemented alongside connectors in PLT-1074/1075/1076). + +--- + +## Module Layout + +``` +src/orcapod/ +├── protocols/ +│ ├── database_protocols.py # existing ArrowDatabaseProtocol (unchanged) +│ └── db_connector_protocol.py # NEW: ColumnInfo, DBConnectorProtocol +├── databases/ +│ ├── __init__.py # updated: export ConnectorArrowDatabase +│ ├── connector_arrow_database.py # NEW: ConnectorArrowDatabase +│ ├── delta_lake_databases.py # existing (unchanged) +│ ├── in_memory_databases.py # existing (unchanged) +│ └── noop_database.py # existing (unchanged) +└── core/sources/ + ├── __init__.py # updated: export DBTableSource + └── db_table_source.py # NEW: DBTableSource +``` + +**Not in this spike** (belong to PLT-1074/1075/1076): +- `databases/sqlite_connector.py` — `SQLiteConnector` +- `databases/postgresql_connector.py` — `PostgreSQLConnector` +- `databases/spiraldb_connector.py` — `SpiralDBConnector` + +--- + +## Tests + +``` +tests/ +├── test_databases/ +│ └── test_connector_arrow_database.py # NEW: protocol conformance + behaviour via mock connector +└── test_core/sources/ + └── test_db_table_source.py # NEW: DBTableSource via mock connector +``` + +Both test files use a `MockDBConnector` (defined in the test file) that holds data in-memory, enabling tests with zero external dependencies. 
+ +--- + +## Design Decisions Log + +| Question | Decision | +|---|---| +| Protocol vs ABC? | `Protocol` — consistent with existing codebase, enables parallel dev without import coupling | +| Separate `ReadableDBConnectorProtocol`? | No — single `DBConnectorProtocol`; `Source` simply doesn't call write methods. Can split later if needed. | +| Generic Source vs per-DB subclasses? | Single `DBTableSource(connector, table_name)` — only the connector varies | +| Type mapping ownership? | Connector — `iter_batches` and `get_column_info` always return Arrow types | +| `record_path → table_name`? | `"__".join(sanitized_parts)` — double underscore separator avoids collisions | +| Upsert abstraction? | `upsert_records(table, id_column, skip_existing)` on connector — hides SQLite/PostgreSQL/SpiralDB dialect differences (INSERT OR IGNORE vs ON CONFLICT DO NOTHING, etc.) | +| Pending-batch location? | In `ConnectorArrowDatabase` (Python-side) — mirrors existing `DeltaTableDatabase`/`InMemoryArrowDatabase` pattern | From ff28391cab1a420d6b5d962f8a95cbb3a30c76e9 Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 06:50:20 +0000 Subject: [PATCH 2/8] docs: address spec review feedback on DBConnector design --- .../specs/2026-03-21-db-connector-design.md | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/superpowers/specs/2026-03-21-db-connector-design.md b/superpowers/specs/2026-03-21-db-connector-design.md index ce751da..71f3ee8 100644 --- a/superpowers/specs/2026-03-21-db-connector-design.md +++ b/superpowers/specs/2026-03-21-db-connector-design.md @@ -144,22 +144,24 @@ Implements `RootSource` on top of any `DBConnectorProtocol`. Read-only. 
DBTableSource( connector: DBConnectorProtocol, table_name: str, - tag_columns: Collection[str] | None = None, # None → use PK columns + tag_columns: Collection[str] | None = None, # None → use PK columns + system_tag_columns: Collection[str] = (), # consistent with DeltaTableSource record_id_column: str | None = None, - source_id: str | None = None, # None → defaults to table_name + source_id: str | None = None, # None → defaults to table_name **kwargs, # passed to RootSource (label, data_context, config) ) ``` Construction flow: -1. Resolve `tag_columns`: if `None`, call `connector.get_pk_columns(table_name)`. -2. Fetch full table: `list(connector.iter_batches(f'SELECT * FROM "{table_name}"'))` → `pa.Table.from_batches(...)`. -3. Feed into `SourceStreamBuilder.build(table, tag_columns=..., source_id=..., record_id_column=...)`. -4. Store result stream, tag_columns, source_id. +1. Resolve `tag_columns`: if `None`, call `connector.get_pk_columns(table_name)`. Raise `ValueError` if the result is empty (table has no primary key and no explicit tag columns were provided). +2. Validate the table exists: if `table_name not in connector.get_table_names()`, raise `ValueError(f"Table {table_name!r} not found in database.")`. This distinguishes "not found" from "empty". +3. Fetch full table: `list(connector.iter_batches(f'SELECT * FROM "{table_name}"'))` → `pa.Table.from_batches(...)`. If the result is empty, raise `ValueError(f"Table {table_name!r} is empty.")` — consistent with `ArrowTableSource`'s behaviour (via `SourceStreamBuilder`) which also rejects empty tables. The `table_name` is always double-quoted in the query string (`f'SELECT * FROM "{table_name}"'`); connectors must support ANSI-standard double-quoted identifiers. +4. Feed into `SourceStreamBuilder.build(table, tag_columns=..., source_id=..., record_id_column=...)`. +5. Store result stream, tag_columns, source_id. 
-`to_config()` serializes `connector.to_config()`, `table_name`, `tag_columns`, `record_id_column`, `source_id` plus identity fields from `_identity_config()`. +`to_config()` serializes `connector.to_config()`, `table_name`, `tag_columns`, `system_tag_columns`, `record_id_column`, `source_id` plus identity fields from `_identity_config()`. -`from_config()` calls a `build_db_connector_from_config(config["connector"])` registry helper (to be implemented alongside connectors in PLT-1074/1075/1076). +`from_config()` raises `NotImplementedError` until connector implementations land in PLT-1074/1075/1076. A `build_db_connector_from_config(config)` registry helper will be added as part of those issues; implementing the registry is **out of scope** for this spike. The config shape uses a `"connector_type"` discriminator key (e.g., `"sqlite"`, `"postgresql"`, `"spiraldb"`) so the registry can dispatch to the correct `from_config` classmethod. --- @@ -198,7 +200,7 @@ tests/ └── test_db_table_source.py # NEW: DBTableSource via mock connector ``` -Both test files use a `MockDBConnector` (defined in the test file) that holds data in-memory, enabling tests with zero external dependencies. +Both test files use a `MockDBConnector` defined in `tests/conftest.py` or inline in the test module. The mock holds data as a `dict[str, pa.Table]` keyed by table name; `iter_batches` slices rows into batches, `get_pk_columns` returns a pre-configured list, `create_table_if_not_exists` is a no-op, and `upsert_records` applies insert-or-replace semantics in memory. This shared mock shape ensures the two test suites use compatible fixtures. --- @@ -213,3 +215,4 @@ Both test files use a `MockDBConnector` (defined in the test file) that holds da | `record_path → table_name`? | `"__".join(sanitized_parts)` — double underscore separator avoids collisions | | Upsert abstraction? 
| `upsert_records(table, id_column, skip_existing)` on connector — hides SQLite/PostgreSQL/SpiralDB dialect differences (INSERT OR IGNORE vs ON CONFLICT DO NOTHING, etc.) | | Pending-batch location? | In `ConnectorArrowDatabase` (Python-side) — mirrors existing `DeltaTableDatabase`/`InMemoryArrowDatabase` pattern | +| Schema evolution on existing tables? | **Out of scope for this spike.** `DBConnectorProtocol` has no `alter_table` method. `ConnectorArrowDatabase` raises `ValueError` if a flush encounters a schema mismatch with an existing table. Schema evolution can be added in a follow-up. | From 805762f973f664bf707b9cb21012ee572697346c Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 07:08:38 +0000 Subject: [PATCH 3/8] docs: add DBConnector implementation plan for PLT-1078 --- .../plans/2026-03-21-db-connector-plan.md | 1479 +++++++++++++++++ 1 file changed, 1479 insertions(+) create mode 100644 superpowers/plans/2026-03-21-db-connector-plan.md diff --git a/superpowers/plans/2026-03-21-db-connector-plan.md b/superpowers/plans/2026-03-21-db-connector-plan.md new file mode 100644 index 0000000..8eb1456 --- /dev/null +++ b/superpowers/plans/2026-03-21-db-connector-plan.md @@ -0,0 +1,1479 @@ +# DBConnector Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Introduce `DBConnectorProtocol`, `ConnectorArrowDatabase`, and `DBTableSource` so that a single DB technology class (SQLiteConnector, PostgreSQLConnector, SpiralDBConnector) powers both the `ArrowDatabaseProtocol` and `Source` layers. + +**Architecture:** A new `DBConnectorProtocol` captures the minimal raw-access interface (schema introspection + Arrow-typed reads + writes + lifecycle). 
`ConnectorArrowDatabase` wraps any connector and implements `ArrowDatabaseProtocol` with pending-batch/flush semantics. `DBTableSource` wraps any connector as a read-only `RootSource`, defaulting to PK columns as tag columns. + +**Tech Stack:** Python 3.11+, PyArrow, `typing.Protocol`, existing `RootSource` / `SourceStreamBuilder` / `ArrowDatabaseProtocol` abstractions. + +--- + +## File Map + +| Action | Path | Responsibility | +|--------|------|---------------| +| Create | `src/orcapod/protocols/db_connector_protocol.py` | `ColumnInfo` dataclass + `DBConnectorProtocol` | +| Modify | `src/orcapod/protocols/database_protocols.py` | Re-export `ColumnInfo`, `DBConnectorProtocol` | +| Create | `src/orcapod/databases/connector_arrow_database.py` | `ConnectorArrowDatabase` (generic `ArrowDatabaseProtocol` impl) | +| Modify | `src/orcapod/databases/__init__.py` | Export `ConnectorArrowDatabase`; update backend comment | +| Create | `src/orcapod/core/sources/db_table_source.py` | `DBTableSource` (generic read-only `RootSource`) | +| Modify | `src/orcapod/core/sources/__init__.py` | Export `DBTableSource` | +| Create | `tests/test_databases/test_connector_arrow_database.py` | Protocol conformance + behaviour tests | +| Create | `tests/test_core/sources/test_db_table_source.py` | `DBTableSource` tests | + +--- + +## Task 1: `ColumnInfo` and `DBConnectorProtocol` + +**Files:** +- Create: `src/orcapod/protocols/db_connector_protocol.py` +- Modify: `src/orcapod/protocols/database_protocols.py` + +- [ ] **Step 1: Write the failing import test** + +```python +# tests/test_databases/test_connector_arrow_database.py (stub for now) +def test_import_db_connector_protocol(): + from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol + assert ColumnInfo is not None + assert DBConnectorProtocol is not None +``` + +- [ ] **Step 2: Run test to confirm it fails** + +```bash +cd /tmp/kurouto-jobs/b8d04a9f-a949-4b75-9ab4-332a63bc70e3/orcapod-python +uv run pytest 
tests/test_databases/test_connector_arrow_database.py::test_import_db_connector_protocol -v +``` +Expected: `ModuleNotFoundError` + +- [ ] **Step 3: Create `src/orcapod/protocols/db_connector_protocol.py`** + +```python +"""DBConnectorProtocol — minimal shared interface for external relational DB backends. + +Each DB technology (SQLite, PostgreSQL, SpiralDB) implements this once. +Both ``ConnectorArrowDatabase`` (read+write) and ``DBTableSource`` (read-only) +depend on it, eliminating duplicated connection management and type-mapping logic. +""" +from __future__ import annotations + +import re +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable + +if TYPE_CHECKING: + import pyarrow as pa + + +@dataclass(frozen=True) +class ColumnInfo: + """Metadata for a single database column with its Arrow-mapped type. + + Type mapping (DB-native → Arrow) is the connector's responsibility. + Consumers of ``DBConnectorProtocol`` always see Arrow types. + + Args: + name: Column name. + arrow_type: Arrow data type (already mapped from the DB-native type). + nullable: Whether the column accepts NULL values. + """ + + name: str + arrow_type: "pa.DataType" + nullable: bool = True + + +@runtime_checkable +class DBConnectorProtocol(Protocol): + """Minimal interface for an external relational database backend. + + Implementations encapsulate: + - Connection lifecycle + - DB-native ↔ Arrow type mapping + - Schema introspection + - Query execution (reads) and record management (writes) + + Read methods are used by both ``ConnectorArrowDatabase`` and ``DBTableSource``. + Write methods (``create_table_if_not_exists``, ``upsert_records``) are used + only by ``ConnectorArrowDatabase``. + + All query results are returned as Arrow types; connectors handle all + DB-native type conversion internally. 
+ + Planned implementations: ``SQLiteConnector`` (PLT-1076), + ``PostgreSQLConnector`` (PLT-1075), ``SpiralDBConnector`` (PLT-1074). + """ + + # ── Schema introspection ────────────────────────────────────────────────── + + def get_table_names(self) -> list[str]: + """Return all available table names in this database.""" + ... + + def get_pk_columns(self, table_name: str) -> list[str]: + """Return primary-key column names for a table, in key-sequence order. + + Returns an empty list if the table has no primary key. + """ + ... + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + """Return column metadata for a table, with types mapped to Arrow.""" + ... + + # ── Read ────────────────────────────────────────────────────────────────── + + def iter_batches( + self, + query: str, + params: Any = None, + batch_size: int = 1000, + ) -> Iterator["pa.RecordBatch"]: + """Execute a query and yield results as Arrow RecordBatches. + + Args: + query: SQL query string. Table names should be double-quoted + (``SELECT * FROM "my_table"``); all connectors must support + ANSI-standard double-quoted identifiers. + params: Optional query parameters (connector-specific format). + batch_size: Maximum rows per yielded batch. + """ + ... + + # ── Write ───────────────────────────────────────────────────────────────── + + def create_table_if_not_exists( + self, + table_name: str, + columns: list[ColumnInfo], + pk_column: str, + ) -> None: + """Create a table with the given columns if it does not already exist. + + Args: + table_name: Table to create. + columns: Column definitions with Arrow-mapped types. + pk_column: Name of the column to use as the primary key. + """ + ... + + def upsert_records( + self, + table_name: str, + records: "pa.Table", + id_column: str, + skip_existing: bool = False, + ) -> None: + """Write records to a table using upsert semantics. + + Args: + table_name: Target table (must already exist). + records: Arrow table of records to write. 
+ id_column: Column used as the unique row identifier. + skip_existing: If ``True``, skip records whose ``id_column`` value + already exists in the table (INSERT OR IGNORE). + If ``False``, overwrite existing records (INSERT OR REPLACE). + """ + ... + + # ── Lifecycle ───────────────────────────────────────────────────────────── + + def close(self) -> None: + """Release the database connection and any associated resources.""" + ... + + def __enter__(self) -> "DBConnectorProtocol": + ... + + def __exit__(self, *args: Any) -> None: + ... + + # ── Serialization ───────────────────────────────────────────────────────── + + def to_config(self) -> dict[str, Any]: + """Serialize connection configuration to a JSON-compatible dict. + + The returned dict must include a ``"connector_type"`` key + (e.g., ``"sqlite"``, ``"postgresql"``, ``"spiraldb"``) so that + a registry helper can dispatch to the correct ``from_config`` + classmethod when deserializing. + """ + ... + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DBConnectorProtocol": + """Reconstruct a connector instance from a config dict.""" + ... 
+``` + +- [ ] **Step 4: Add re-export to `src/orcapod/protocols/database_protocols.py`** + +Append to the end of the existing file: + +```python +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol + +__all__ = [ + "ArrowDatabaseProtocol", + "ArrowDatabaseWithMetadataProtocol", + "ColumnInfo", + "DBConnectorProtocol", + "MetadataCapableProtocol", +] +``` + +- [ ] **Step 5: Run test to confirm it passes** + +```bash +uv run pytest tests/test_databases/test_connector_arrow_database.py::test_import_db_connector_protocol -v +``` +Expected: PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/protocols/db_connector_protocol.py src/orcapod/protocols/database_protocols.py tests/test_databases/test_connector_arrow_database.py +git commit -m "feat(protocols): add ColumnInfo and DBConnectorProtocol for PLT-1078" +``` + +--- + +## Task 2: `ConnectorArrowDatabase` + +**Files:** +- Create: `src/orcapod/databases/connector_arrow_database.py` +- Modify: `src/orcapod/databases/__init__.py` + +- [ ] **Step 1: Write the failing protocol-conformance test** + +Replace the stub test file content with: + +```python +"""Tests for ConnectorArrowDatabase — protocol conformance and behaviour via MockDBConnector.""" +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + + +# --------------------------------------------------------------------------- +# MockDBConnector — in-memory implementation of DBConnectorProtocol for tests +# --------------------------------------------------------------------------- + + +class MockDBConnector: + """Minimal in-memory DBConnectorProtocol implementation for testing.""" + + def __init__( + self, + tables: dict[str, pa.Table] | None = None, + 
pk_columns: dict[str, list[str]] | None = None, + ): + self._tables: dict[str, pa.Table] = dict(tables or {}) + self._pk_columns: dict[str, list[str]] = dict(pk_columns or {}) + + def get_table_names(self) -> list[str]: + return list(self._tables.keys()) + + def get_pk_columns(self, table_name: str) -> list[str]: + return list(self._pk_columns.get(table_name, [])) + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + schema = self._tables[table_name].schema + return [ + ColumnInfo(name=f.name, arrow_type=f.type, nullable=f.nullable) + for f in schema + ] + + def iter_batches( + self, query: str, params: Any = None, batch_size: int = 1000 + ) -> Iterator[pa.RecordBatch]: + import re + match = re.search(r'FROM\s+"?(\w+)"?', query, re.IGNORECASE) + if not match: + return + table_name = match.group(1) + table = self._tables.get(table_name) + if table is None or table.num_rows == 0: + return + for batch in table.to_batches(max_chunksize=batch_size): + yield batch + + def create_table_if_not_exists( + self, table_name: str, columns: list[ColumnInfo], pk_column: str + ) -> None: + if table_name not in self._tables: + self._tables[table_name] = pa.table( + {c.name: pa.array([], type=c.arrow_type) for c in columns} + ) + self._pk_columns.setdefault(table_name, [pk_column]) + + def upsert_records( + self, table_name: str, records: pa.Table, id_column: str, skip_existing: bool = False + ) -> None: + existing = self._tables.get(table_name) + if existing is None or existing.num_rows == 0: + self._tables[table_name] = records + return + new_ids = set(records[id_column].to_pylist()) + if skip_existing: + existing_ids = set(existing[id_column].to_pylist()) + mask = pc.invert( + pc.is_in(records[id_column], pa.array(list(new_ids & existing_ids))) + ) + to_add = records.filter(mask) + if to_add.num_rows > 0: + self._tables[table_name] = pa.concat_tables([existing, to_add]) + else: + mask = pc.invert(pc.is_in(existing[id_column], pa.array(list(new_ids)))) + kept = 
existing.filter(mask) + self._tables[table_name] = pa.concat_tables([kept, records]) + + def close(self) -> None: + pass + + def __enter__(self) -> "MockDBConnector": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def to_config(self) -> dict[str, Any]: + return {"connector_type": "mock"} + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MockDBConnector": + return cls() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def connector(): + return MockDBConnector() + + +@pytest.fixture +def db(connector): + from orcapod.databases import ConnectorArrowDatabase + return ConnectorArrowDatabase(connector) + + +def make_table(**columns: list) -> pa.Table: + return pa.table({k: pa.array(v) for k, v in columns.items()}) + + +# --------------------------------------------------------------------------- +# Protocol conformance +# --------------------------------------------------------------------------- + + +class TestProtocolConformance: + def test_satisfies_arrow_database_protocol(self, db): + assert isinstance(db, ArrowDatabaseProtocol) + + def test_mock_satisfies_db_connector_protocol(self, connector): + assert isinstance(connector, DBConnectorProtocol) +``` + +- [ ] **Step 2: Run test to confirm it fails** + +```bash +uv run pytest tests/test_databases/test_connector_arrow_database.py::TestProtocolConformance -v +``` +Expected: `ImportError: cannot import name 'ConnectorArrowDatabase'` + +- [ ] **Step 3: Create `src/orcapod/databases/connector_arrow_database.py`** + +```python +"""ConnectorArrowDatabase — generic ArrowDatabaseProtocol backed by any DBConnectorProtocol. 
+ +Implements the full ArrowDatabaseProtocol on top of any DBConnectorProtocol, +owning all record-management logic: record_path → table name mapping, +``__record_id`` column convention, in-memory pending-batch management, +deduplication, upsert, and flush. + +Connector implementations (SQLiteConnector, PostgreSQLConnector, SpiralDBConnector) +need only satisfy DBConnectorProtocol; they do not implement ArrowDatabaseProtocol. +""" +from __future__ import annotations + +import re +from collections import defaultdict +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any, cast + +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + import pyarrow.compute as pc +else: + pa = LazyModule("pyarrow") + pc = LazyModule("pyarrow.compute") + + +def _arrow_schema_to_column_infos(schema: "pa.Schema") -> list[ColumnInfo]: + """Convert a PyArrow schema to a list of ColumnInfo.""" + return [ + ColumnInfo(name=field.name, arrow_type=field.type, nullable=field.nullable) + for field in schema + ] + + +class ConnectorArrowDatabase: + """Generic ``ArrowDatabaseProtocol`` implementation backed by a ``DBConnectorProtocol``. + + Records are buffered in memory (pending batch) and written to the connector + on ``flush()``. The ``record_path`` tuple is mapped to a sanitized SQL table + name using ``"__".join(sanitized_parts)``. + + Args: + connector: A ``DBConnectorProtocol`` implementation providing the + underlying DB access (connection, type mapping, queries, writes). + max_hierarchy_depth: Maximum allowed length for ``record_path`` tuples. 
+ + Example:: + + connector = SQLiteConnector(":memory:") # PLT-1076 + db = ConnectorArrowDatabase(connector) + db.add_record(("results", "my_fn"), record_id="abc", record=table) + db.flush() + """ + + RECORD_ID_COLUMN = "__record_id" + _ROW_INDEX_COLUMN = "__row_index" + + def __init__( + self, + connector: DBConnectorProtocol, + max_hierarchy_depth: int = 10, + ) -> None: + self._connector = connector + self.max_hierarchy_depth = max_hierarchy_depth + self._pending_batches: dict[str, pa.Table] = {} + self._pending_record_ids: dict[str, set[str]] = defaultdict(set) + + # ── Path helpers ────────────────────────────────────────────────────────── + + def _get_record_key(self, record_path: tuple[str, ...]) -> str: + return "/".join(record_path) + + def _path_to_table_name(self, record_path: tuple[str, ...]) -> str: + """Map a record_path to a safe SQL table name. + + Each component is sanitized (non-alphanumeric chars → ``_``), then + joined with ``__`` as separator. A ``t_`` prefix is added if the result + starts with a digit to ensure a valid SQL identifier. 
+ """ + parts = [re.sub(r"[^a-zA-Z0-9_]", "_", part) for part in record_path] + name = "__".join(parts) + if name and name[0].isdigit(): + name = "t_" + name + return name + + def _validate_record_path(self, record_path: tuple[str, ...]) -> None: + if not record_path: + raise ValueError("record_path cannot be empty") + if len(record_path) > self.max_hierarchy_depth: + raise ValueError( + f"record_path depth {len(record_path)} exceeds maximum " + f"{self.max_hierarchy_depth}" + ) + for i, component in enumerate(record_path): + if not component or not isinstance(component, str): + raise ValueError( + f"record_path component {i} is invalid: {repr(component)}" + ) + + # ── Record-ID column helpers ────────────────────────────────────────────── + + def _ensure_record_id_column( + self, arrow_data: "pa.Table", record_id: str + ) -> "pa.Table": + if self.RECORD_ID_COLUMN not in arrow_data.column_names: + key_array = pa.array( + [record_id] * len(arrow_data), type=pa.large_string() + ) + arrow_data = arrow_data.add_column(0, self.RECORD_ID_COLUMN, key_array) + return arrow_data + + def _remove_record_id_column(self, arrow_data: "pa.Table") -> "pa.Table": + if self.RECORD_ID_COLUMN in arrow_data.column_names: + arrow_data = arrow_data.drop([self.RECORD_ID_COLUMN]) + return arrow_data + + def _handle_record_id_column( + self, arrow_data: "pa.Table", record_id_column: str | None + ) -> "pa.Table": + if not record_id_column: + return self._remove_record_id_column(arrow_data) + if self.RECORD_ID_COLUMN in arrow_data.column_names: + new_names = [ + record_id_column if name == self.RECORD_ID_COLUMN else name + for name in arrow_data.schema.names + ] + return arrow_data.rename_columns(new_names) + raise ValueError( + f"Record ID column '{self.RECORD_ID_COLUMN}' not found in table." 
+ ) + + # ── Deduplication ───────────────────────────────────────────────────────── + + def _deduplicate_within_table(self, table: "pa.Table") -> "pa.Table": + """Keep the last occurrence of each record ID within a single table.""" + if table.num_rows <= 1: + return table + indices = pa.array(range(table.num_rows)) + table_with_idx = table.add_column(0, self._ROW_INDEX_COLUMN, indices) + grouped = table_with_idx.group_by([self.RECORD_ID_COLUMN]).aggregate( + [(self._ROW_INDEX_COLUMN, "max")] + ) + max_indices = grouped[f"{self._ROW_INDEX_COLUMN}_max"].to_pylist() + mask = pc.is_in(indices, pa.array(max_indices)) + return table.filter(mask) + + # ── Committed data access ───────────────────────────────────────────────── + + def _get_committed_table( + self, record_path: tuple[str, ...] + ) -> "pa.Table | None": + """Fetch all committed records for a path from the connector.""" + table_name = self._path_to_table_name(record_path) + if table_name not in self._connector.get_table_names(): + return None + batches = list( + self._connector.iter_batches(f'SELECT * FROM "{table_name}"') + ) + if not batches: + return None + return pa.Table.from_batches(batches) + + # ── Write methods ───────────────────────────────────────────────────────── + + def add_record( + self, + record_path: tuple[str, ...], + record_id: str, + record: "pa.Table", + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + """Add a single record identified by ``record_id``.""" + data_with_id = self._ensure_record_id_column(record, record_id) + self.add_records( + record_path=record_path, + records=data_with_id, + record_id_column=self.RECORD_ID_COLUMN, + skip_duplicates=skip_duplicates, + flush=flush, + ) + + def add_records( + self, + record_path: tuple[str, ...], + records: "pa.Table", + record_id_column: str | None = None, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + """Add multiple records to the pending batch.""" + self._validate_record_path(record_path) + 
if records.num_rows == 0: + return + + if record_id_column is None: + record_id_column = records.column_names[0] + if record_id_column not in records.column_names: + raise ValueError( + f"record_id_column '{record_id_column}' not found in table columns: " + f"{records.column_names}" + ) + + # Normalise to internal column name + if record_id_column != self.RECORD_ID_COLUMN: + rename_map = {record_id_column: self.RECORD_ID_COLUMN} + records = records.rename_columns( + [rename_map.get(c, c) for c in records.column_names] + ) + + records = self._deduplicate_within_table(records) + record_key = self._get_record_key(record_path) + input_ids = set(cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist())) + + if skip_duplicates: + committed = self._get_committed_table(record_path) + committed_ids: set[str] = set() + if committed is not None: + committed_ids = set( + cast(list[str], committed[self.RECORD_ID_COLUMN].to_pylist()) + ) + all_existing = (input_ids & self._pending_record_ids[record_key]) | ( + input_ids & committed_ids + ) + if all_existing: + mask = pc.invert( + pc.is_in( + records[self.RECORD_ID_COLUMN], pa.array(list(all_existing)) + ) + ) + records = records.filter(mask) + if records.num_rows == 0: + return + else: + conflicts = input_ids & self._pending_record_ids[record_key] + if conflicts: + raise ValueError( + f"Records with IDs {conflicts} already exist in the pending batch. " + "Use skip_duplicates=True to skip them." 
+ ) + + # Buffer in pending batch + existing_pending = self._pending_batches.get(record_key) + if existing_pending is None: + self._pending_batches[record_key] = records + else: + self._pending_batches[record_key] = pa.concat_tables( + [existing_pending, records] + ) + self._pending_record_ids[record_key].update( + cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist()) + ) + + if flush: + self.flush() + + # ── Flush ───────────────────────────────────────────────────────────────── + + def flush(self) -> None: + """Commit all pending batches to the connector via upsert.""" + for record_key in list(self._pending_batches.keys()): + record_path = tuple(record_key.split("/")) + table_name = self._path_to_table_name(record_path) + pending = self._pending_batches.pop(record_key) + self._pending_record_ids.pop(record_key, None) + + columns = _arrow_schema_to_column_infos(pending.schema) + self._connector.create_table_if_not_exists( + table_name, columns, pk_column=self.RECORD_ID_COLUMN + ) + self._connector.upsert_records( + table_name, + pending, + id_column=self.RECORD_ID_COLUMN, + skip_existing=False, + ) + + # ── Read methods ────────────────────────────────────────────────────────── + + def get_record_by_id( + self, + record_path: tuple[str, ...], + record_id: str, + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + record_key = self._get_record_key(record_path) + + # Check pending first + if record_id in self._pending_record_ids.get(record_key, set()): + pending = self._pending_batches[record_key] + filtered = pending.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows > 0: + return self._handle_record_id_column(filtered, record_id_column) + + # Check committed + committed = self._get_committed_table(record_path) + if committed is None: + return None + filtered = committed.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows == 0: + return None + 
return self._handle_record_id_column(filtered, record_id_column) + + def get_all_records( + self, + record_path: tuple[str, ...], + record_id_column: str | None = None, + ) -> "pa.Table | None": + record_key = self._get_record_key(record_path) + parts: list[pa.Table] = [] + + committed = self._get_committed_table(record_path) + if committed is not None and committed.num_rows > 0: + parts.append(committed) + pending = self._pending_batches.get(record_key) + if pending is not None and pending.num_rows > 0: + parts.append(pending) + + if not parts: + return None + table = parts[0] if len(parts) == 1 else pa.concat_tables(parts) + return self._handle_record_id_column(table, record_id_column) + + def get_records_by_ids( + self, + record_path: tuple[str, ...], + record_ids: Collection[str], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + ids_list = list(record_ids) + if not ids_list: + return None + all_records = self.get_all_records( + record_path, record_id_column=self.RECORD_ID_COLUMN + ) + if all_records is None: + return None + filtered = all_records.filter( + pc.is_in(all_records[self.RECORD_ID_COLUMN], pa.array(ids_list)) + ) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + def get_records_with_column_value( + self, + record_path: tuple[str, ...], + column_values: Collection[tuple[str, Any]] | Mapping[str, Any], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + all_records = self.get_all_records( + record_path, record_id_column=self.RECORD_ID_COLUMN + ) + if all_records is None: + return None + + if isinstance(column_values, Mapping): + pairs = list(column_values.items()) + else: + pairs = cast(list[tuple[str, Any]], list(column_values)) + + expr = None + for col, val in pairs: + e = pc.field(col) == val + expr = e if expr is None else expr & e # type: 
ignore[assignment] + + filtered = all_records.filter(expr) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + # ── Config ──────────────────────────────────────────────────────────────── + + def to_config(self) -> dict[str, Any]: + """Serialize configuration to a JSON-compatible dict.""" + return { + "type": "connector_arrow_database", + "connector": self._connector.to_config(), + "max_hierarchy_depth": self.max_hierarchy_depth, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ConnectorArrowDatabase": + """Reconstruct a ConnectorArrowDatabase from config. + + Raises: + NotImplementedError: Always — requires a connector registry that + maps ``connector_type`` keys to ``from_config`` classmethods. + Implement alongside connector classes in PLT-1074/1075/1076. + """ + raise NotImplementedError( + "ConnectorArrowDatabase.from_config requires a registered connector " + "factory (connector_type → class). Implement in PLT-1074/1075/1076." 
+ ) +``` + +- [ ] **Step 4: Update `src/orcapod/databases/__init__.py`** + +```python +from .connector_arrow_database import ConnectorArrowDatabase +from .delta_lake_databases import DeltaTableDatabase +from .in_memory_databases import InMemoryArrowDatabase +from .noop_database import NoOpArrowDatabase + +__all__ = [ + "ConnectorArrowDatabase", + "DeltaTableDatabase", + "InMemoryArrowDatabase", + "NoOpArrowDatabase", +] + +# Relational DB connector implementations (satisfy DBConnectorProtocol from +# orcapod.protocols.db_connector_protocol) and can be passed to +# ConnectorArrowDatabase or DBTableSource: +# +# SQLiteConnector -- PLT-1076 (stdlib sqlite3, zero extra deps) +# PostgreSQLConnector -- PLT-1075 (psycopg3) +# SpiralDBConnector -- PLT-1074 +# +# ArrowDatabaseProtocol backends (existing, not connector-based): +# +# DeltaTableDatabase -- Delta Lake (deltalake package) +# InMemoryArrowDatabase -- pure in-memory, for tests +# NoOpArrowDatabase -- no-op, for dry-runs / benchmarks +``` + +- [ ] **Step 5: Run conformance test to verify it passes** + +```bash +uv run pytest tests/test_databases/test_connector_arrow_database.py::TestProtocolConformance -v +``` +Expected: PASS + +- [ ] **Step 6: Add behaviour tests to the test file** + +Append to `tests/test_databases/test_connector_arrow_database.py`: + +```python +# --------------------------------------------------------------------------- +# Behaviour tests +# --------------------------------------------------------------------------- + + +class TestAddAndGetRecord: + def test_add_record_and_get_by_id(self, db): + record = make_table(value=[42]) + db.add_record(("test", "path"), record_id="r1", record=record, flush=True) + result = db.get_record_by_id(("test", "path"), "r1") + assert result is not None + assert result["value"][0].as_py() == 42 + + def test_get_record_not_found_returns_none(self, db): + assert db.get_record_by_id(("missing",), "nope") is None + + def 
test_add_record_pending_visible_before_flush(self, db): + record = make_table(value=[1]) + db.add_record(("p",), record_id="x", record=record) + result = db.get_record_by_id(("p",), "x") + assert result is not None + + def test_get_all_records_returns_pending_and_committed(self, db): + db.add_record(("t",), "a", make_table(v=[1]), flush=True) + db.add_record(("t",), "b", make_table(v=[2])) + all_r = db.get_all_records(("t",)) + assert all_r is not None + assert all_r.num_rows == 2 + + def test_skip_duplicates_true_ignores_existing(self, db): + db.add_record(("t",), "a", make_table(v=[1]), flush=True) + db.add_record(("t",), "a", make_table(v=[99]), skip_duplicates=True, flush=True) + result = db.get_record_by_id(("t",), "a") + assert result["v"][0].as_py() == 1 # original value preserved + + def test_skip_duplicates_false_raises_on_pending_conflict(self, db): + db.add_record(("t",), "a", make_table(v=[1])) + with pytest.raises(ValueError, match="already exist in the pending batch"): + db.add_record(("t",), "a", make_table(v=[2])) + + def test_flush_writes_to_connector(self, connector, db): + db.add_record(("fn",), "h1", make_table(x=[10])) + assert db.get_all_records(("fn",)) is not None # in pending + db.flush() + # after flush the connector should have the table + table_name = db._path_to_table_name(("fn",)) + assert table_name in connector.get_table_names() + + def test_empty_record_path_raises(self, db): + with pytest.raises(ValueError, match="cannot be empty"): + db.add_record((), "x", make_table(v=[1])) + + def test_get_records_by_ids(self, db): + db.add_record(("t",), "a", make_table(v=[1]), flush=True) + db.add_record(("t",), "b", make_table(v=[2]), flush=True) + result = db.get_records_by_ids(("t",), ["a"]) + assert result is not None + assert result.num_rows == 1 + + def test_get_records_with_column_value(self, db): + db.add_records( + ("t",), + pa.table({"__record_id": pa.array(["a", "b"]), "kind": pa.array(["x", "y"])}), + 
record_id_column="__record_id", + flush=True, + ) + result = db.get_records_with_column_value(("t",), {"kind": "x"}) + assert result is not None + assert result.num_rows == 1 + + +class TestPathToTableName: + def test_simple_path(self, db): + assert db._path_to_table_name(("results", "my_fn")) == "results__my_fn" + + def test_special_chars_sanitized(self, db): + name = db._path_to_table_name(("a:b", "c/d")) + assert "__" in name + assert ":" not in name + assert "/" not in name + + def test_digit_prefix_gets_t_prefix(self, db): + name = db._path_to_table_name(("1abc",)) + assert name.startswith("t_") +``` + +- [ ] **Step 7: Run all tests** + +```bash +uv run pytest tests/test_databases/test_connector_arrow_database.py -v +``` +Expected: All PASS + +- [ ] **Step 8: Commit** + +```bash +git add src/orcapod/databases/connector_arrow_database.py src/orcapod/databases/__init__.py tests/test_databases/test_connector_arrow_database.py +git commit -m "feat(databases): add ConnectorArrowDatabase for PLT-1078" +``` + +--- + +## Task 3: `DBTableSource` + +**Files:** +- Create: `src/orcapod/core/sources/db_table_source.py` +- Modify: `src/orcapod/core/sources/__init__.py` + +- [ ] **Step 1: Write the failing import test** + +```python +# tests/test_core/sources/test_db_table_source.py +def test_import_db_table_source(): + from orcapod.core.sources import DBTableSource + assert DBTableSource is not None +``` + +- [ ] **Step 2: Run to confirm failure** + +```bash +uv run pytest tests/test_core/sources/test_db_table_source.py::test_import_db_table_source -v +``` +Expected: `ImportError` + +- [ ] **Step 3: Create `src/orcapod/core/sources/db_table_source.py`** + +```python +"""DBTableSource — a read-only RootSource backed by any DBConnectorProtocol. + +Uses the table's primary-key columns as tag columns by default. +Type mapping (DB-native → Arrow) is fully delegated to the connector. 
+
+Example::
+
+    connector = SQLiteConnector(":memory:")  # PLT-1076
+    source = DBTableSource(connector, "measurements")  # PKs → tags
+    source = DBTableSource(connector, "events", tag_columns=["session_id"])
+"""
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING, Any
+
+from orcapod.core.sources.base import RootSource
+from orcapod.core.sources.stream_builder import SourceStreamBuilder
+from orcapod.utils.lazy_module import LazyModule
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+    from orcapod.protocols.db_connector_protocol import DBConnectorProtocol
+else:
+    pa = LazyModule("pyarrow")
+
+
+class DBTableSource(RootSource):
+    """A read-only Source backed by a table in any DBConnectorProtocol database.
+
+    At construction time the source:
+    1. Validates the table exists in the connector.
+    2. Resolves tag columns (defaults to the table's primary-key columns).
+    3. Fetches all rows as Arrow batches and assembles a PyArrow table.
+    4. Enriches via ``SourceStreamBuilder`` (source-info, schema-hash, system tags).
+
+    Args:
+        connector: A ``DBConnectorProtocol`` providing DB access.
+        table_name: Name of the table to expose as a source.
+        tag_columns: Columns to use as tag columns. If ``None`` (default),
+            the table's primary-key columns are used. Raises ``ValueError``
+            if the table has no primary key and no explicit columns are given.
+        system_tag_columns: Additional system-level tag columns (passed through
+            to ``SourceStreamBuilder``; mirrors ``DeltaTableSource`` API).
+        record_id_column: Column for stable per-row record IDs in provenance
+            strings. If ``None``, row indices are used.
+        source_id: Canonical source name for the registry and provenance tokens.
+            Defaults to ``table_name``.
+        **kwargs: Forwarded to ``RootSource`` (``label``, ``data_context``,
+            ``config``).
+
+    Raises:
+        ValueError: If the table is not found, has no PK columns and none are
+            provided, or is empty.
+ """ + + def __init__( + self, + connector: "DBConnectorProtocol", + table_name: str, + tag_columns: Collection[str] | None = None, + system_tag_columns: Collection[str] = (), + record_id_column: str | None = None, + source_id: str | None = None, + **kwargs: Any, + ) -> None: + if source_id is None: + source_id = table_name + super().__init__(source_id=source_id, **kwargs) + + self._connector = connector + self._table_name = table_name + self._record_id_column = record_id_column + + # Validate the table exists first so the error is always "not found" + # rather than a misleading "no primary key" error for missing tables. + if table_name not in connector.get_table_names(): + raise ValueError(f"Table {table_name!r} not found in database.") + + # Resolve tag columns — default to PK columns + if tag_columns is None: + resolved_tag_columns: list[str] = connector.get_pk_columns(table_name) + if not resolved_tag_columns: + raise ValueError( + f"Table {table_name!r} has no primary key columns. " + "Provide explicit tag_columns." 
+ ) + else: + resolved_tag_columns = list(tag_columns) + + # Fetch the full table as Arrow + batches = list(connector.iter_batches(f'SELECT * FROM "{table_name}"')) + if not batches: + raise ValueError(f"Table {table_name!r} is empty.") + table: pa.Table = pa.Table.from_batches(batches) + + # Enrich via SourceStreamBuilder (same pipeline as all other RootSources) + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + table, + tag_columns=resolved_tag_columns, + source_id=self._source_id, + record_id_column=record_id_column, + system_tag_columns=system_tag_columns, + ) + + self._stream = result.stream + self._tag_columns = result.tag_columns + self._system_tag_columns = result.system_tag_columns + if self._source_id is None: + self._source_id = result.source_id + + def to_config(self) -> dict[str, Any]: + """Serialize source configuration to a JSON-compatible dict.""" + return { + "source_type": "db_table", + "connector": self._connector.to_config(), + "table_name": self._table_name, + "tag_columns": list(self._tag_columns), + "system_tag_columns": list(self._system_tag_columns), + "record_id_column": self._record_id_column, + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DBTableSource": + """Not yet implemented — requires a connector factory registry. + + Raises: + NotImplementedError: Always, until connector implementations add + a registry helper (``build_db_connector_from_config``) in + PLT-1074/1075/1076. + """ + raise NotImplementedError( + "DBTableSource.from_config requires a registered connector factory. " + "Implement build_db_connector_from_config in PLT-1074/1075/1076." + ) +``` + +- [ ] **Step 4: Update `src/orcapod/core/sources/__init__.py`** + +Add `DBTableSource` to the imports and `__all__`: + +```python +from .db_table_source import DBTableSource +``` + +Add `"DBTableSource"` to the `__all__` list. 
+ +- [ ] **Step 5: Run import test to confirm it passes** + +```bash +uv run pytest tests/test_core/sources/test_db_table_source.py::test_import_db_table_source -v +``` +Expected: PASS + +- [ ] **Step 6: Add full test suite to `tests/test_core/sources/test_db_table_source.py`** + +Replace the entire file with: + +```python +"""Tests for DBTableSource using MockDBConnector (no external DB required).""" +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +from orcapod.core.sources import DBTableSource +from orcapod.protocols.core_protocols import SourceProtocol, StreamProtocol +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.hashing_protocols import PipelineElementProtocol + + +# --------------------------------------------------------------------------- +# MockDBConnector (same interface as in test_connector_arrow_database.py) +# --------------------------------------------------------------------------- + + +class MockDBConnector: + """Minimal in-memory DBConnectorProtocol for testing DBTableSource.""" + + def __init__( + self, + tables: dict[str, pa.Table] | None = None, + pk_columns: dict[str, list[str]] | None = None, + ): + self._tables: dict[str, pa.Table] = dict(tables or {}) + self._pk_columns: dict[str, list[str]] = dict(pk_columns or {}) + + def get_table_names(self) -> list[str]: + return list(self._tables.keys()) + + def get_pk_columns(self, table_name: str) -> list[str]: + return list(self._pk_columns.get(table_name, [])) + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + schema = self._tables[table_name].schema + return [ColumnInfo(name=f.name, arrow_type=f.type) for f in schema] + + def iter_batches( + self, query: str, params: Any = None, batch_size: int = 1000 + ) -> Iterator[pa.RecordBatch]: + import re + match = re.search(r'FROM\s+"?(\w+)"?', query, 
re.IGNORECASE) + if not match: + return + table_name = match.group(1) + table = self._tables.get(table_name) + if table is None or table.num_rows == 0: + return + for batch in table.to_batches(max_chunksize=batch_size): + yield batch + + def create_table_if_not_exists(self, *args: Any, **kwargs: Any) -> None: + pass + + def upsert_records(self, *args: Any, **kwargs: Any) -> None: + pass + + def close(self) -> None: + pass + + def __enter__(self) -> "MockDBConnector": + return self + + def __exit__(self, *args: Any) -> None: + pass + + def to_config(self) -> dict[str, Any]: + return {"connector_type": "mock"} + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MockDBConnector": + return cls() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def measurements_table() -> pa.Table: + return pa.table( + { + "session_id": pa.array(["s1", "s2", "s3"], type=pa.large_string()), + "trial": pa.array([1, 2, 3], type=pa.int64()), + "response": pa.array([0.1, 0.2, 0.3], type=pa.float64()), + } + ) + + +@pytest.fixture +def connector(measurements_table) -> MockDBConnector: + return MockDBConnector( + tables={"measurements": measurements_table}, + pk_columns={"measurements": ["session_id"]}, + ) + + +@pytest.fixture +def source(connector) -> DBTableSource: + return DBTableSource(connector, "measurements") + + +# --------------------------------------------------------------------------- +# Protocol conformance +# --------------------------------------------------------------------------- + + +class TestProtocolConformance: + def test_is_source_protocol(self, source): + assert isinstance(source, SourceProtocol) + + def test_is_stream_protocol(self, source): + assert isinstance(source, StreamProtocol) + + def test_is_pipeline_element_protocol(self, source): + assert isinstance(source, PipelineElementProtocol) + + def 
test_connector_satisfies_db_connector_protocol(self, connector): + assert isinstance(connector, DBConnectorProtocol) + + +# --------------------------------------------------------------------------- +# Construction behaviour +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_pk_columns_used_as_default_tag_columns(self, source): + tag_schema, _ = source.output_schema() + assert "session_id" in tag_schema + + def test_explicit_tag_columns_override_pk(self, connector): + src = DBTableSource(connector, "measurements", tag_columns=["trial"]) + tag_schema, _ = src.output_schema() + assert "trial" in tag_schema + assert "session_id" not in tag_schema + + def test_default_source_id_is_table_name(self, source): + assert source.source_id == "measurements" + + def test_explicit_source_id_is_used(self, connector): + src = DBTableSource(connector, "measurements", source_id="my_source") + assert src.source_id == "my_source" + + def test_missing_table_raises_value_error(self, connector): + with pytest.raises(ValueError, match="not found in database"): + DBTableSource(connector, "nonexistent") + + def test_no_pk_and_no_tag_columns_raises(self, measurements_table): + connector = MockDBConnector( + tables={"t": measurements_table}, + pk_columns={}, # no PKs + ) + with pytest.raises(ValueError, match="no primary key"): + DBTableSource(connector, "t") + + def test_empty_table_raises_value_error(self, connector): + connector._tables["empty"] = pa.table( + {"id": pa.array([], type=pa.large_string())} + ) + connector._pk_columns["empty"] = ["id"] + with pytest.raises(ValueError, match="is empty"): + DBTableSource(connector, "empty") + + +# --------------------------------------------------------------------------- +# Stream behaviour +# --------------------------------------------------------------------------- + + +class TestStreamBehaviour: + def test_no_upstream_producer(self, source): + assert source.producer is 
None + + def test_empty_upstreams(self, source): + assert source.upstreams == () + + def test_iter_packets_yields_correct_count(self, source, measurements_table): + packets = list(source.iter_packets()) + assert len(packets) == measurements_table.num_rows + + def test_output_schema_has_correct_columns(self, source): + tag_schema, packet_schema = source.output_schema() + assert "session_id" in tag_schema + assert "trial" in packet_schema + assert "response" in packet_schema + + def test_as_table_returns_arrow_table(self, source, measurements_table): + t = source.as_table() + assert t.num_rows == measurements_table.num_rows + + def test_pipeline_hash_is_deterministic(self, connector): + src1 = DBTableSource(connector, "measurements") + src2 = DBTableSource(connector, "measurements") + assert src1.pipeline_hash() == src2.pipeline_hash() + + def test_content_hash_is_deterministic(self, connector): + src1 = DBTableSource(connector, "measurements") + src2 = DBTableSource(connector, "measurements") + assert src1.content_hash() == src2.content_hash() + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_to_config_has_required_keys(self, source): + config = source.to_config() + assert config["source_type"] == "db_table" + assert config["table_name"] == "measurements" + assert "tag_columns" in config + assert "connector" in config + assert config["connector"]["connector_type"] == "mock" + + def test_from_config_raises_not_implemented(self, source): + config = source.to_config() + with pytest.raises(NotImplementedError): + DBTableSource.from_config(config) +``` + +- [ ] **Step 7: Run all tests** + +```bash +uv run pytest tests/test_core/sources/test_db_table_source.py -v +``` +Expected: All PASS + +- [ ] **Step 8: Commit** + +```bash +git add src/orcapod/core/sources/db_table_source.py src/orcapod/core/sources/__init__.py 
tests/test_core/sources/test_db_table_source.py +git commit -m "feat(sources): add DBTableSource for PLT-1078" +``` + +--- + +## Task 4: Run the Full Test Suite + +- [ ] **Step 1: Run all existing tests to verify no regressions** + +```bash +uv run pytest tests/ -v --tb=short +``` +Expected: All previously-passing tests still PASS; new tests also PASS. + +- [ ] **Step 2: If any failures, fix them before proceeding** + +- [ ] **Step 3: Commit any fixes** + +```bash +git add -A +git commit -m "fix: address regressions from PLT-1078 interface changes" +``` + +--- + +## Task 5: Post Design Note to Linear + +- [ ] **Step 1: Post the design decisions as a comment on PLT-1078** + +Post a comment summarising: three-layer architecture (DBConnectorProtocol → ConnectorArrowDatabase / DBTableSource), all six design decisions from the decisions log, the new file map, and pointers to PLT-1074/1075/1076 for next steps. + +- [ ] **Step 2: Update PLT-1078 status to "In Review"** (or the equivalent completed status) + +--- + +## Task 6: Open PR + +- [ ] **Step 1: Push the branch** + +```bash +gh-app-token-generator nauticalab | gh auth login --with-token +git push -u origin eywalker/plt-1078-design-spike-clean-up-databasesource-interface-to-streamline +``` + +- [ ] **Step 2: Open PR targeting `dev`** + +```bash +gh pr create \ + --title "feat: DBConnectorProtocol + ConnectorArrowDatabase + DBTableSource (PLT-1078)" \ + --base dev \ + --body "$(cat <<'EOF' +## Summary + +Closes PLT-1078 — design spike: clean up Database/Source interface. + +Introduces a three-layer abstraction so each DB technology (SQLite, PostgreSQL, SpiralDB) only needs to implement **one class** (`DBConnectorProtocol`) to power both the `ArrowDatabaseProtocol` layer (read+write memoization) and the `Source` layer (read-only pipeline ingestion). 
+ +### New files +- `src/orcapod/protocols/db_connector_protocol.py` — `ColumnInfo`, `DBConnectorProtocol` +- `src/orcapod/databases/connector_arrow_database.py` — `ConnectorArrowDatabase` +- `src/orcapod/core/sources/db_table_source.py` — `DBTableSource` +- `tests/test_databases/test_connector_arrow_database.py` +- `tests/test_core/sources/test_db_table_source.py` + +### Updated files +- `src/orcapod/protocols/database_protocols.py` — re-exports +- `src/orcapod/databases/__init__.py` — exports + backend comments +- `src/orcapod/core/sources/__init__.py` — exports + +### Key decisions +| Question | Decision | +|---|---| +| Protocol vs ABC? | `Protocol` (structural subtyping, no import coupling) | +| Generic source vs per-DB subclasses? | Single `DBTableSource(connector, table_name)` | +| Type mapping ownership? | Connector — callers always see Arrow types | +| Upsert abstraction? | `upsert_records(skip_existing)` hides SQL dialect differences | +| Pending-batch location? | `ConnectorArrowDatabase` (Python-side, mirrors existing impls) | +| Schema evolution? 
| Out of scope for this spike; `ValueError` on mismatch | + +### Unblocks +- PLT-1074 (SpiralDBConnector) +- PLT-1075 (PostgreSQLConnector) +- PLT-1076 (SQLiteConnector) +- PLT-1072, PLT-1073, PLT-1077 (DB-backed Sources — just `DBTableSource(connector, table_name)`) + +## Test plan +- [x] `ConnectorArrowDatabase` satisfies `ArrowDatabaseProtocol` (isinstance check) +- [x] `MockDBConnector` satisfies `DBConnectorProtocol` (isinstance check) +- [x] Full add/get/flush/skip-duplicates behaviour via mock connector +- [x] `DBTableSource` satisfies `SourceProtocol`, `StreamProtocol`, `PipelineElementProtocol` +- [x] PK columns used as default tag columns; explicit override works +- [x] Missing table / empty table / no-PK errors raised correctly +- [x] `to_config` round-trip; `from_config` raises `NotImplementedError` +- [x] No regressions in existing test suite + +🤖 Generated with [Claude Code](https://claude.com/claude-code) +EOF +)" +``` From ebfbc711c2cabafe5f4580cecaf22a77ad238a00 Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 08:05:30 +0000 Subject: [PATCH 4/8] feat(plt-1078): add DBConnectorProtocol, ConnectorArrowDatabase, DBTableSource --- src/orcapod/core/sources/__init__.py | 2 + src/orcapod/core/sources/db_table_source.py | 139 ++++ src/orcapod/databases/__init__.py | 21 +- .../databases/connector_arrow_database.py | 406 +++++++++++ src/orcapod/protocols/database_protocols.py | 13 + .../protocols/db_connector_protocol.py | 160 +++++ .../test_core/sources/test_db_table_source.py | 441 ++++++++++++ .../test_connector_arrow_database.py | 678 ++++++++++++++++++ 8 files changed, 1852 insertions(+), 8 deletions(-) create mode 100644 src/orcapod/core/sources/db_table_source.py create mode 100644 src/orcapod/databases/connector_arrow_database.py create mode 100644 src/orcapod/protocols/db_connector_protocol.py create mode 100644 tests/test_core/sources/test_db_table_source.py create 
mode 100644 tests/test_databases/test_connector_arrow_database.py diff --git a/src/orcapod/core/sources/__init__.py b/src/orcapod/core/sources/__init__.py index 3f3a795..d0337fb 100644 --- a/src/orcapod/core/sources/__init__.py +++ b/src/orcapod/core/sources/__init__.py @@ -3,6 +3,7 @@ from .cached_source import CachedSource from .csv_source import CSVSource from .data_frame_source import DataFrameSource +from .db_table_source import DBTableSource from .delta_table_source import DeltaTableSource from .derived_source import DerivedSource from .dict_source import DictSource @@ -16,6 +17,7 @@ "CachedSource", "CSVSource", "DataFrameSource", + "DBTableSource", "DeltaTableSource", "DerivedSource", "DictSource", diff --git a/src/orcapod/core/sources/db_table_source.py b/src/orcapod/core/sources/db_table_source.py new file mode 100644 index 0000000..873f231 --- /dev/null +++ b/src/orcapod/core/sources/db_table_source.py @@ -0,0 +1,139 @@ +"""DBTableSource — a read-only RootSource backed by any DBConnectorProtocol. + +Uses the table's primary-key columns as tag columns by default. +Type mapping (DB-native → Arrow) is fully delegated to the connector. + +Example:: + + connector = SQLiteConnector(":memory:") # PLT-1076 + source = DBTableSource(connector, "measurements") # PKs → tags + source = DBTableSource(connector, "events", tag_columns=["session_id"]) +""" +from __future__ import annotations + +from collections.abc import Collection +from typing import TYPE_CHECKING, Any + +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.db_connector_protocol import DBConnectorProtocol +else: + pa = LazyModule("pyarrow") + + +class DBTableSource(RootSource): + """A read-only Source backed by a table in any DBConnectorProtocol database. + + At construction time the source: + 1. 
Validates the table exists in the connector. + 2. Resolves tag columns (defaults to the table's primary-key columns). + 3. Fetches all rows as Arrow batches and assembles a PyArrow table. + 4. Enriches via ``SourceStreamBuilder`` (source-info, schema-hash, system tags). + + Args: + connector: A ``DBConnectorProtocol`` providing DB access. + table_name: Name of the table to expose as a source. + tag_columns: Columns to use as tag columns. If ``None`` (default), + the table's primary-key columns are used. Raises ``ValueError`` + if the table has no primary key and no explicit columns are given. + system_tag_columns: Additional system-level tag columns (passed through + to ``SourceStreamBuilder``; mirrors ``DeltaTableSource`` API). + record_id_column: Column for stable per-row record IDs in provenance + strings. If ``None``, row indices are used. + source_id: Canonical source name for the registry and provenance tokens. + Defaults to ``table_name``. + **kwargs: Forwarded to ``RootSource`` (``label``, ``data_context``, + ``config``). + + Raises: + ValueError: If the table is not found, has no PK columns and none are + provided, or is empty. + """ + + def __init__( + self, + connector: "DBConnectorProtocol", + table_name: str, + tag_columns: Collection[str] | None = None, + system_tag_columns: Collection[str] = (), + record_id_column: str | None = None, + source_id: str | None = None, + **kwargs: Any, + ) -> None: + if source_id is None: + source_id = table_name + super().__init__(source_id=source_id, **kwargs) + + self._connector = connector + self._table_name = table_name + self._record_id_column = record_id_column + + # Step 1: Validate the table exists first so the error is always + # "not found" rather than a misleading "no primary key" for missing tables. 
+ if table_name not in connector.get_table_names(): + raise ValueError(f"Table {table_name!r} not found in database.") + + # Step 2: Resolve tag columns — default to PK columns + if tag_columns is None: + resolved_tag_columns: list[str] = connector.get_pk_columns(table_name) + if not resolved_tag_columns: + raise ValueError( + f"Table {table_name!r} has no primary key columns. " + "Provide explicit tag_columns." + ) + else: + resolved_tag_columns = list(tag_columns) + + # Step 3: Fetch the full table as Arrow + batches = list(connector.iter_batches(f'SELECT * FROM "{table_name}"')) + if not batches: + raise ValueError(f"Table {table_name!r} is empty.") + table: pa.Table = pa.Table.from_batches(batches) + + # Step 4: Enrich via SourceStreamBuilder (same pipeline as all other RootSources) + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + table, + tag_columns=resolved_tag_columns, + source_id=self._source_id, + record_id_column=record_id_column, + system_tag_columns=system_tag_columns, + ) + + self._stream = result.stream + self._tag_columns = result.tag_columns + self._system_tag_columns = result.system_tag_columns + if self._source_id is None: + self._source_id = result.source_id + + def to_config(self) -> dict[str, Any]: + """Serialize source configuration to a JSON-compatible dict.""" + return { + "source_type": "db_table", + "connector": self._connector.to_config(), + "table_name": self._table_name, + "tag_columns": list(self._tag_columns), + "system_tag_columns": list(self._system_tag_columns), + "record_id_column": self._record_id_column, + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DBTableSource": + """Not yet implemented — requires a connector factory registry. + + Raises: + NotImplementedError: Always, until connector implementations add + a registry helper (``build_db_connector_from_config``) in + PLT-1074/1075/1076. 
+ """ + raise NotImplementedError( + "DBTableSource.from_config requires a registered connector factory. " + "Implement build_db_connector_from_config in PLT-1074/1075/1076." + ) diff --git a/src/orcapod/databases/__init__.py b/src/orcapod/databases/__init__.py index e8556e8..fe59b3d 100644 --- a/src/orcapod/databases/__init__.py +++ b/src/orcapod/databases/__init__.py @@ -1,21 +1,26 @@ +from .connector_arrow_database import ConnectorArrowDatabase from .delta_lake_databases import DeltaTableDatabase from .in_memory_databases import InMemoryArrowDatabase from .noop_database import NoOpArrowDatabase __all__ = [ + "ConnectorArrowDatabase", "DeltaTableDatabase", "InMemoryArrowDatabase", "NoOpArrowDatabase", ] -# Future ArrowDatabaseProtocol backends to implement: +# Relational DB connector implementations satisfy DBConnectorProtocol +# (orcapod.protocols.db_connector_protocol) and can be passed to either +# ConnectorArrowDatabase (read+write ArrowDatabaseProtocol) or +# DBTableSource (read-only Source): # -# ParquetArrowDatabase -- stores each record_path as a partitioned Parquet -# directory; simpler, no Delta Lake dependency, -# suitable for write-once / read-heavy workloads. +# SQLiteConnector -- PLT-1076 (stdlib sqlite3, zero extra deps) +# PostgreSQLConnector -- PLT-1075 (psycopg3) +# SpiralDBConnector -- PLT-1074 # -# IcebergArrowDatabase -- Apache Iceberg backend for cloud-native / -# object-store deployments. +# ArrowDatabaseProtocol backends (existing, not connector-based): # -# All backends must satisfy the ArrowDatabaseProtocol protocol defined in -# orcapod.protocols.database_protocols. 
+# DeltaTableDatabase -- Delta Lake (deltalake package) +# InMemoryArrowDatabase -- pure in-memory, for tests +# NoOpArrowDatabase -- no-op, for dry-runs / benchmarks diff --git a/src/orcapod/databases/connector_arrow_database.py b/src/orcapod/databases/connector_arrow_database.py new file mode 100644 index 0000000..7d999cf --- /dev/null +++ b/src/orcapod/databases/connector_arrow_database.py @@ -0,0 +1,406 @@ +"""ConnectorArrowDatabase — generic ArrowDatabaseProtocol backed by any DBConnectorProtocol. + +Implements the full ArrowDatabaseProtocol on top of any DBConnectorProtocol, +owning all record-management logic: record_path → table name mapping, +``__record_id`` column convention, in-memory pending-batch management, +deduplication, upsert, and flush. + +Connector implementations (SQLiteConnector, PostgreSQLConnector, SpiralDBConnector) +need only satisfy DBConnectorProtocol; they do not implement ArrowDatabaseProtocol. + +Example:: + + connector = SQLiteConnector(":memory:") # PLT-1076 + db = ConnectorArrowDatabase(connector) + db.add_record(("results", "my_fn"), record_id="abc", record=table) + db.flush() +""" +from __future__ import annotations + +import re +from collections import defaultdict +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any, cast + +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + import pyarrow.compute as pc +else: + pa = LazyModule("pyarrow") + pc = LazyModule("pyarrow.compute") + + +def _arrow_schema_to_column_infos(schema: "pa.Schema") -> list[ColumnInfo]: + """Convert a PyArrow schema to a list of ColumnInfo.""" + return [ + ColumnInfo(name=field.name, arrow_type=field.type, nullable=field.nullable) + for field in schema + ] + + +class ConnectorArrowDatabase: + """Generic ``ArrowDatabaseProtocol`` implementation backed by a ``DBConnectorProtocol``. 
+ + Records are buffered in memory (pending batch) and written to the connector + on ``flush()``. The ``record_path`` tuple is mapped to a sanitized SQL table + name using ``"__".join(sanitized_parts)``. + + Args: + connector: A ``DBConnectorProtocol`` implementation providing the + underlying DB access (connection, type mapping, queries, writes). + max_hierarchy_depth: Maximum allowed length for ``record_path`` tuples. + Defaults to 10, matching ``InMemoryArrowDatabase``. + """ + + RECORD_ID_COLUMN = "__record_id" + _ROW_INDEX_COLUMN = "__row_index" + + def __init__( + self, + connector: DBConnectorProtocol, + max_hierarchy_depth: int = 10, + ) -> None: + self._connector = connector + self.max_hierarchy_depth = max_hierarchy_depth + self._pending_batches: dict[str, pa.Table] = {} + self._pending_record_ids: dict[str, set[str]] = defaultdict(set) + + # ── Path helpers ────────────────────────────────────────────────────────── + + def _get_record_key(self, record_path: tuple[str, ...]) -> str: + return "/".join(record_path) + + def _path_to_table_name(self, record_path: tuple[str, ...]) -> str: + """Map a record_path to a safe SQL table name. + + Each component is sanitized (non-alphanumeric chars → ``_``), then + joined with ``__`` as separator. A ``t_`` prefix is added if the result + starts with a digit to ensure a valid SQL identifier. 
+ """ + parts = [re.sub(r"[^a-zA-Z0-9_]", "_", part) for part in record_path] + name = "__".join(parts) + if name and name[0].isdigit(): + name = "t_" + name + return name + + def _validate_record_path(self, record_path: tuple[str, ...]) -> None: + if not record_path: + raise ValueError("record_path cannot be empty") + if len(record_path) > self.max_hierarchy_depth: + raise ValueError( + f"record_path depth {len(record_path)} exceeds maximum " + f"{self.max_hierarchy_depth}" + ) + for i, component in enumerate(record_path): + if not component or not isinstance(component, str): + raise ValueError( + f"record_path component {i} is invalid: {repr(component)}" + ) + + # ── Record-ID column helpers ────────────────────────────────────────────── + + def _ensure_record_id_column( + self, arrow_data: "pa.Table", record_id: str + ) -> "pa.Table": + if self.RECORD_ID_COLUMN not in arrow_data.column_names: + key_array = pa.array( + [record_id] * len(arrow_data), type=pa.large_string() + ) + arrow_data = arrow_data.add_column(0, self.RECORD_ID_COLUMN, key_array) + return arrow_data + + def _remove_record_id_column(self, arrow_data: "pa.Table") -> "pa.Table": + if self.RECORD_ID_COLUMN in arrow_data.column_names: + arrow_data = arrow_data.drop([self.RECORD_ID_COLUMN]) + return arrow_data + + def _handle_record_id_column( + self, arrow_data: "pa.Table", record_id_column: str | None + ) -> "pa.Table": + if not record_id_column: + return self._remove_record_id_column(arrow_data) + if self.RECORD_ID_COLUMN in arrow_data.column_names: + new_names = [ + record_id_column if name == self.RECORD_ID_COLUMN else name + for name in arrow_data.schema.names + ] + return arrow_data.rename_columns(new_names) + raise ValueError( + f"Record ID column '{self.RECORD_ID_COLUMN}' not found in table." 
+ ) + + # ── Deduplication ───────────────────────────────────────────────────────── + + def _deduplicate_within_table(self, table: "pa.Table") -> "pa.Table": + """Keep the last occurrence of each record ID within a single table.""" + if table.num_rows <= 1: + return table + indices = pa.array(range(table.num_rows)) + table_with_idx = table.add_column(0, self._ROW_INDEX_COLUMN, indices) + grouped = table_with_idx.group_by([self.RECORD_ID_COLUMN]).aggregate( + [(self._ROW_INDEX_COLUMN, "max")] + ) + max_indices = grouped[f"{self._ROW_INDEX_COLUMN}_max"].to_pylist() + mask = pc.is_in(indices, pa.array(max_indices)) + return table.filter(mask) + + # ── Committed data access ───────────────────────────────────────────────── + + def _get_committed_table( + self, record_path: tuple[str, ...] + ) -> "pa.Table | None": + """Fetch all committed records for a path from the connector.""" + table_name = self._path_to_table_name(record_path) + if table_name not in self._connector.get_table_names(): + return None + batches = list( + self._connector.iter_batches(f'SELECT * FROM "{table_name}"') + ) + if not batches: + return None + return pa.Table.from_batches(batches) + + # ── Write methods ───────────────────────────────────────────────────────── + + def add_record( + self, + record_path: tuple[str, ...], + record_id: str, + record: "pa.Table", + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + """Add a single record identified by ``record_id``.""" + data_with_id = self._ensure_record_id_column(record, record_id) + self.add_records( + record_path=record_path, + records=data_with_id, + record_id_column=self.RECORD_ID_COLUMN, + skip_duplicates=skip_duplicates, + flush=flush, + ) + + def add_records( + self, + record_path: tuple[str, ...], + records: "pa.Table", + record_id_column: str | None = None, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + """Add multiple records to the pending batch.""" + self._validate_record_path(record_path) + 
if records.num_rows == 0: + return + + if record_id_column is None: + record_id_column = records.column_names[0] + if record_id_column not in records.column_names: + raise ValueError( + f"record_id_column '{record_id_column}' not found in table columns: " + f"{records.column_names}" + ) + + # Normalise to internal column name + if record_id_column != self.RECORD_ID_COLUMN: + rename_map = {record_id_column: self.RECORD_ID_COLUMN} + records = records.rename_columns( + [rename_map.get(c, c) for c in records.column_names] + ) + + records = self._deduplicate_within_table(records) + record_key = self._get_record_key(record_path) + input_ids = set(cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist())) + + if skip_duplicates: + committed = self._get_committed_table(record_path) + committed_ids: set[str] = set() + if committed is not None: + committed_ids = set( + cast(list[str], committed[self.RECORD_ID_COLUMN].to_pylist()) + ) + all_existing = (input_ids & self._pending_record_ids[record_key]) | ( + input_ids & committed_ids + ) + if all_existing: + mask = pc.invert( + pc.is_in( + records[self.RECORD_ID_COLUMN], pa.array(list(all_existing)) + ) + ) + records = records.filter(mask) + if records.num_rows == 0: + return + else: + conflicts = input_ids & self._pending_record_ids[record_key] + if conflicts: + raise ValueError( + f"Records with IDs {conflicts} already exist in the pending batch. " + "Use skip_duplicates=True to skip them." 
+ ) + + # Buffer in pending batch + existing_pending = self._pending_batches.get(record_key) + if existing_pending is None: + self._pending_batches[record_key] = records + else: + self._pending_batches[record_key] = pa.concat_tables( + [existing_pending, records] + ) + self._pending_record_ids[record_key].update( + cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist()) + ) + + if flush: + self.flush() + + # ── Flush ───────────────────────────────────────────────────────────────── + + def flush(self) -> None: + """Commit all pending batches to the connector via upsert.""" + for record_key in list(self._pending_batches.keys()): + record_path = tuple(record_key.split("/")) + table_name = self._path_to_table_name(record_path) + pending = self._pending_batches.pop(record_key) + self._pending_record_ids.pop(record_key, None) + + columns = _arrow_schema_to_column_infos(pending.schema) + self._connector.create_table_if_not_exists( + table_name, columns, pk_column=self.RECORD_ID_COLUMN + ) + self._connector.upsert_records( + table_name, + pending, + id_column=self.RECORD_ID_COLUMN, + skip_existing=False, + ) + + # ── Read methods ────────────────────────────────────────────────────────── + + def get_record_by_id( + self, + record_path: tuple[str, ...], + record_id: str, + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + record_key = self._get_record_key(record_path) + + # Check pending first + if record_id in self._pending_record_ids.get(record_key, set()): + pending = self._pending_batches[record_key] + filtered = pending.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows > 0: + return self._handle_record_id_column(filtered, record_id_column) + + # Check committed + committed = self._get_committed_table(record_path) + if committed is None: + return None + filtered = committed.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows == 0: + return None + 
return self._handle_record_id_column(filtered, record_id_column) + + def get_all_records( + self, + record_path: tuple[str, ...], + record_id_column: str | None = None, + ) -> "pa.Table | None": + record_key = self._get_record_key(record_path) + parts: list[pa.Table] = [] + + committed = self._get_committed_table(record_path) + if committed is not None and committed.num_rows > 0: + parts.append(committed) + pending = self._pending_batches.get(record_key) + if pending is not None and pending.num_rows > 0: + parts.append(pending) + + if not parts: + return None + table = parts[0] if len(parts) == 1 else pa.concat_tables(parts) + return self._handle_record_id_column(table, record_id_column) + + def get_records_by_ids( + self, + record_path: tuple[str, ...], + record_ids: Collection[str], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + ids_list = list(record_ids) + if not ids_list: + return None + all_records = self.get_all_records( + record_path, record_id_column=self.RECORD_ID_COLUMN + ) + if all_records is None: + return None + filtered = all_records.filter( + pc.is_in(all_records[self.RECORD_ID_COLUMN], pa.array(ids_list)) + ) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + def get_records_with_column_value( + self, + record_path: tuple[str, ...], + column_values: Collection[tuple[str, Any]] | Mapping[str, Any], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + all_records = self.get_all_records( + record_path, record_id_column=self.RECORD_ID_COLUMN + ) + if all_records is None: + return None + + if isinstance(column_values, Mapping): + pairs = list(column_values.items()) + else: + pairs = cast(list[tuple[str, Any]], list(column_values)) + + expr = None + for col, val in pairs: + e = pc.field(col) == val + expr = e if expr is None else expr & e # type: 
ignore[assignment] + + filtered = all_records.filter(expr) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + # ── Config ──────────────────────────────────────────────────────────────── + + def to_config(self) -> dict[str, Any]: + """Serialize configuration to a JSON-compatible dict.""" + return { + "type": "connector_arrow_database", + "connector": self._connector.to_config(), + "max_hierarchy_depth": self.max_hierarchy_depth, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ConnectorArrowDatabase": + """Reconstruct a ConnectorArrowDatabase from config. + + Raises: + NotImplementedError: Always — requires a connector registry that + maps ``connector_type`` keys to ``from_config`` classmethods. + Implement alongside connector classes in PLT-1074/1075/1076. + """ + raise NotImplementedError( + "ConnectorArrowDatabase.from_config requires a registered connector " + "factory (connector_type → class). Implement in PLT-1074/1075/1076." 
+ ) diff --git a/src/orcapod/protocols/database_protocols.py b/src/orcapod/protocols/database_protocols.py index 9af7608..4033ad5 100644 --- a/src/orcapod/protocols/database_protocols.py +++ b/src/orcapod/protocols/database_protocols.py @@ -102,3 +102,16 @@ class ArrowDatabaseWithMetadataProtocol( """A protocol that combines ArrowDatabaseProtocol with metadata capabilities.""" pass + + +# Re-export connector abstractions so callers can import everything DB-related +# from one place: ``from orcapod.protocols.database_protocols import ...`` +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol # noqa: E402 + +__all__ = [ + "ArrowDatabaseProtocol", + "ArrowDatabaseWithMetadataProtocol", + "ColumnInfo", + "DBConnectorProtocol", + "MetadataCapableProtocol", +] diff --git a/src/orcapod/protocols/db_connector_protocol.py b/src/orcapod/protocols/db_connector_protocol.py new file mode 100644 index 0000000..7b0dc2a --- /dev/null +++ b/src/orcapod/protocols/db_connector_protocol.py @@ -0,0 +1,160 @@ +"""DBConnectorProtocol — minimal shared interface for external relational DB backends. + +Each DB technology (SQLite, PostgreSQL, SpiralDB) implements this once. +Both ``ConnectorArrowDatabase`` (read+write) and ``DBTableSource`` (read-only) +depend on it, eliminating duplicated connection management and type-mapping logic. + +Planned implementations: + SQLiteConnector -- PLT-1076 (stdlib sqlite3, zero extra deps) + PostgreSQLConnector -- PLT-1075 (psycopg3) + SpiralDBConnector -- PLT-1074 +""" +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable + +if TYPE_CHECKING: + import pyarrow as pa + + +@dataclass(frozen=True) +class ColumnInfo: + """Metadata for a single database column with its Arrow-mapped type. + + Type mapping (DB-native → Arrow) is the connector's responsibility. 
+ Consumers of ``DBConnectorProtocol`` always see Arrow types. + + Args: + name: Column name. + arrow_type: Arrow data type (already mapped from the DB-native type). + nullable: Whether the column accepts NULL values. + """ + + name: str + arrow_type: "pa.DataType" + nullable: bool = True + + +@runtime_checkable +class DBConnectorProtocol(Protocol): + """Minimal interface for an external relational database backend. + + Implementations encapsulate: + - Connection lifecycle + - DB-native ↔ Arrow type mapping + - Schema introspection + - Query execution (reads) and record management (writes) + + Read methods are used by both ``ConnectorArrowDatabase`` and ``DBTableSource``. + Write methods (``create_table_if_not_exists``, ``upsert_records``) are used + only by ``ConnectorArrowDatabase``. + + All query results are returned as Arrow types; connectors handle all + DB-native type conversion internally. + + Planned implementations: ``SQLiteConnector`` (PLT-1076), + ``PostgreSQLConnector`` (PLT-1075), ``SpiralDBConnector`` (PLT-1074). + """ + + # ── Schema introspection ────────────────────────────────────────────────── + + def get_table_names(self) -> list[str]: + """Return all available table names in this database.""" + ... + + def get_pk_columns(self, table_name: str) -> list[str]: + """Return primary-key column names for a table, in key-sequence order. + + Returns an empty list if the table has no primary key. + """ + ... + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + """Return column metadata for a table, with types mapped to Arrow.""" + ... + + # ── Read ────────────────────────────────────────────────────────────────── + + def iter_batches( + self, + query: str, + params: Any = None, + batch_size: int = 1000, + ) -> Iterator["pa.RecordBatch"]: + """Execute a query and yield results as Arrow RecordBatches. + + Args: + query: SQL query string. 
Table names should be double-quoted + (``SELECT * FROM "my_table"``); all connectors must support + ANSI-standard double-quoted identifiers. + params: Optional query parameters (connector-specific format). + batch_size: Maximum rows per yielded batch. + """ + ... + + # ── Write ───────────────────────────────────────────────────────────────── + + def create_table_if_not_exists( + self, + table_name: str, + columns: list[ColumnInfo], + pk_column: str, + ) -> None: + """Create a table with the given columns if it does not already exist. + + Args: + table_name: Table to create. + columns: Column definitions with Arrow-mapped types. + pk_column: Name of the column to use as the primary key. + """ + ... + + def upsert_records( + self, + table_name: str, + records: "pa.Table", + id_column: str, + skip_existing: bool = False, + ) -> None: + """Write records to a table using upsert semantics. + + Args: + table_name: Target table (must already exist). + records: Arrow table of records to write. + id_column: Column used as the unique row identifier. + skip_existing: If ``True``, skip records whose ``id_column`` value + already exists in the table (INSERT OR IGNORE). + If ``False``, overwrite existing records (INSERT OR REPLACE). + """ + ... + + # ── Lifecycle ───────────────────────────────────────────────────────────── + + def close(self) -> None: + """Release the database connection and any associated resources.""" + ... + + def __enter__(self) -> "DBConnectorProtocol": + ... + + def __exit__(self, *args: Any) -> None: + ... + + # ── Serialization ───────────────────────────────────────────────────────── + + def to_config(self) -> dict[str, Any]: + """Serialize connection configuration to a JSON-compatible dict. + + The returned dict must include a ``"connector_type"`` key + (e.g., ``"sqlite"``, ``"postgresql"``, ``"spiraldb"``) so that + a registry helper can dispatch to the correct ``from_config`` + classmethod when deserializing. + """ + ... 
+ + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DBConnectorProtocol": + """Reconstruct a connector instance from a config dict.""" + ... diff --git a/tests/test_core/sources/test_db_table_source.py b/tests/test_core/sources/test_db_table_source.py new file mode 100644 index 0000000..0816dc9 --- /dev/null +++ b/tests/test_core/sources/test_db_table_source.py @@ -0,0 +1,441 @@ +""" +Comprehensive tests for DBTableSource using MockDBConnector (no external DB required). + +Test sections: + 1. Import / export sanity + 2. MockDBConnector satisfies DBConnectorProtocol + 3. DBTableSource protocol conformance (SourceProtocol, StreamProtocol, PipelineElementProtocol) + 4. Construction — default tag columns (PK), explicit tag columns, source_id + 5. Construction error cases — missing table, no PK columns, empty table + 6. Stream behaviour — iter_packets count, output_schema, as_table, producer/upstreams + 7. Deterministic hashing (pipeline_hash, content_hash) + 8. Config — to_config shape, from_config raises NotImplementedError +""" +from __future__ import annotations + +import re +from collections.abc import Iterator +from typing import Any + +import pyarrow as pa +import pytest + +from orcapod.core.sources import DBTableSource +from orcapod.protocols.core_protocols import SourceProtocol, StreamProtocol +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.hashing_protocols import PipelineElementProtocol + + +# --------------------------------------------------------------------------- +# MockDBConnector — minimal in-memory DBConnectorProtocol for these tests +# --------------------------------------------------------------------------- + + +class MockDBConnector: + """Read-only in-memory connector for DBTableSource tests. + + Write methods (create_table_if_not_exists, upsert_records) are no-ops + because DBTableSource never calls them. 
+ """ + + def __init__( + self, + tables: dict[str, pa.Table] | None = None, + pk_columns: dict[str, list[str]] | None = None, + ): + self._tables: dict[str, pa.Table] = dict(tables or {}) + self._pk_columns: dict[str, list[str]] = dict(pk_columns or {}) + + def get_table_names(self) -> list[str]: + return list(self._tables.keys()) + + def get_pk_columns(self, table_name: str) -> list[str]: + return list(self._pk_columns.get(table_name, [])) + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + schema = self._tables[table_name].schema + return [ColumnInfo(name=f.name, arrow_type=f.type) for f in schema] + + def iter_batches( + self, query: str, params: Any = None, batch_size: int = 1000 + ) -> Iterator[pa.RecordBatch]: + match = re.search(r'FROM\s+"?(\w+)"?', query, re.IGNORECASE) + if not match: + return + table_name = match.group(1) + table = self._tables.get(table_name) + if table is None or table.num_rows == 0: + return + for batch in table.to_batches(max_chunksize=batch_size): + yield batch + + def create_table_if_not_exists(self, *args: Any, **kwargs: Any) -> None: + pass # not used by DBTableSource + + def upsert_records(self, *args: Any, **kwargs: Any) -> None: + pass # not used by DBTableSource + + def close(self) -> None: + pass + + def __enter__(self) -> "MockDBConnector": + return self + + def __exit__(self, *args: Any) -> None: + pass + + def to_config(self) -> dict[str, Any]: + return {"connector_type": "mock"} + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MockDBConnector": + return cls() + + +# --------------------------------------------------------------------------- +# Standard fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def measurements_table() -> pa.Table: + return pa.table( + { + "session_id": pa.array(["s1", "s2", "s3"], type=pa.large_string()), + "trial": pa.array([1, 2, 3], type=pa.int64()), + "response": pa.array([0.1, 0.2, 0.3], 
type=pa.float64()), + } + ) + + +@pytest.fixture +def connector(measurements_table) -> MockDBConnector: + return MockDBConnector( + tables={"measurements": measurements_table}, + pk_columns={"measurements": ["session_id"]}, + ) + + +@pytest.fixture +def source(connector) -> DBTableSource: + return DBTableSource(connector, "measurements") + + +# =========================================================================== +# 1. Import / export sanity +# =========================================================================== + + +def test_import_db_table_source_from_core_sources(): + from orcapod.core.sources import DBTableSource as _DBTableSource + assert _DBTableSource is not None + + +def test_db_table_source_is_in_all(): + import orcapod.core.sources as sources_module + assert "DBTableSource" in sources_module.__all__ + + +# =========================================================================== +# 2. MockDBConnector satisfies DBConnectorProtocol +# =========================================================================== + + +class TestMockConnectorProtocol: + def test_satisfies_db_connector_protocol(self, connector): + assert isinstance(connector, DBConnectorProtocol) + + def test_get_table_names_returns_list(self, connector): + names = connector.get_table_names() + assert isinstance(names, list) + assert "measurements" in names + + def test_get_pk_columns_returns_list(self, connector): + pks = connector.get_pk_columns("measurements") + assert pks == ["session_id"] + + def test_iter_batches_yields_record_batches(self, connector): + batches = list(connector.iter_batches('SELECT * FROM "measurements"')) + assert len(batches) > 0 + assert all(isinstance(b, pa.RecordBatch) for b in batches) + + def test_iter_batches_total_rows_match(self, connector, measurements_table): + batches = list(connector.iter_batches('SELECT * FROM "measurements"')) + total = sum(b.num_rows for b in batches) + assert total == measurements_table.num_rows + + def 
test_iter_batches_missing_table_yields_nothing(self, connector): + batches = list(connector.iter_batches('SELECT * FROM "no_such_table"')) + assert batches == [] + + +# =========================================================================== +# 3. DBTableSource protocol conformance +# =========================================================================== + + +class TestProtocolConformance: + def test_is_source_protocol(self, source): + assert isinstance(source, SourceProtocol) + + def test_is_stream_protocol(self, source): + assert isinstance(source, StreamProtocol) + + def test_is_pipeline_element_protocol(self, source): + assert isinstance(source, PipelineElementProtocol) + + def test_has_iter_packets(self, source): + assert callable(source.iter_packets) + + def test_has_output_schema(self, source): + assert callable(source.output_schema) + + def test_has_as_table(self, source): + assert callable(source.as_table) + + def test_has_to_config(self, source): + assert callable(source.to_config) + + def test_has_from_config(self, source): + assert callable(source.from_config) + + +# =========================================================================== +# 4. 
# ===========================================================================
# 4. Construction — tag columns and source_id
# ===========================================================================


class TestConstruction:
    """Tag-column selection and source_id defaults at construction time."""

    def test_pk_columns_used_as_default_tag_columns(self, source):
        tag_schema, _ = source.output_schema()
        assert "session_id" in tag_schema

    def test_pk_tag_column_not_in_packet_schema(self, source):
        tag_schema, packet_schema = source.output_schema()
        assert "session_id" in tag_schema
        assert "session_id" not in packet_schema

    def test_non_pk_columns_in_packet_schema(self, source):
        _, packet_schema = source.output_schema()
        assert "trial" in packet_schema
        assert "response" in packet_schema

    def test_explicit_tag_columns_override_pk(self, connector):
        src = DBTableSource(connector, "measurements", tag_columns=["trial"])
        tag_schema, packet_schema = src.output_schema()
        assert "trial" in tag_schema
        assert "session_id" not in tag_schema

    def test_multiple_explicit_tag_columns(self, connector):
        src = DBTableSource(
            connector, "measurements", tag_columns=["session_id", "trial"]
        )
        tag_schema, _ = src.output_schema()
        assert "session_id" in tag_schema
        assert "trial" in tag_schema

    def test_default_source_id_is_table_name(self, source):
        assert source.source_id == "measurements"

    def test_explicit_source_id_overrides_default(self, connector):
        src = DBTableSource(connector, "measurements", source_id="my_meas")
        assert src.source_id == "my_meas"

    def test_table_with_multiple_pk_columns(self, measurements_table):
        connector = MockDBConnector(
            tables={"t": measurements_table},
            pk_columns={"t": ["session_id", "trial"]},
        )
        src = DBTableSource(connector, "t")
        tag_schema, _ = src.output_schema()
        assert "session_id" in tag_schema
        assert "trial" in tag_schema
# ===========================================================================
# 5. Construction error cases
# ===========================================================================


class TestConstructionErrors:
    """Error messages for missing, PK-less, and empty tables must be distinct."""

    def test_missing_table_raises_value_error(self, connector):
        with pytest.raises(ValueError, match="not found in database"):
            DBTableSource(connector, "nonexistent")

    def test_missing_table_error_not_confused_with_no_pk(self, measurements_table):
        """A missing table should raise 'not found', not 'no primary key'."""
        connector = MockDBConnector(
            tables={"t": measurements_table},
            pk_columns={},  # no PKs registered
        )
        with pytest.raises(ValueError, match="not found in database"):
            DBTableSource(connector, "completely_missing")

    def test_no_pk_and_no_explicit_tags_raises_value_error(self, measurements_table):
        connector = MockDBConnector(
            tables={"t": measurements_table},
            pk_columns={},  # table exists but has no PK
        )
        with pytest.raises(ValueError, match="no primary key"):
            DBTableSource(connector, "t")

    def test_empty_table_raises_value_error(self, connector):
        connector._tables["empty"] = pa.table(
            {"id": pa.array([], type=pa.large_string())}
        )
        connector._pk_columns["empty"] = ["id"]
        with pytest.raises(ValueError, match="is empty"):
            DBTableSource(connector, "empty")

    def test_empty_table_error_distinguishable_from_missing_table(self, connector):
        """The two error messages must be distinct."""
        connector._tables["empty"] = pa.table(
            {"id": pa.array([], type=pa.large_string())}
        )
        connector._pk_columns["empty"] = ["id"]
        empty_err: ValueError | None = None
        missing_err: ValueError | None = None
        try:
            DBTableSource(connector, "empty")
        except ValueError as exc:
            empty_err = exc
        try:
            DBTableSource(connector, "nonexistent")
        except ValueError as exc:
            missing_err = exc
        assert empty_err is not None
        assert missing_err is not None
        # They must be different messages
        assert "not found" not in str(empty_err)
        assert "is empty" not in str(missing_err)
# ===========================================================================
# 6. Stream behaviour
# ===========================================================================


class TestStreamBehaviour:
    """Iteration, schema, and table materialisation behaviour of the source."""

    def test_producer_is_none(self, source):
        """Root sources have no upstream producer."""
        assert source.producer is None

    def test_upstreams_is_empty(self, source):
        assert source.upstreams == ()

    def test_iter_packets_yields_one_packet_per_row(self, source, measurements_table):
        packets = list(source.iter_packets())
        assert len(packets) == measurements_table.num_rows

    def test_iter_packets_each_has_tag_and_packet(self, source):
        # Tag and Packet are named types (not plain dict) but support
        # dict-like access and containment checks.
        for tags, packet in source.iter_packets():
            assert "session_id" in tags
            assert "trial" in packet or "response" in packet

    def test_output_schema_returns_two_schemas(self, source):
        result = source.output_schema()
        assert len(result) == 2

    def test_output_schema_tag_schema_is_dict_like(self, source):
        tag_schema, _ = source.output_schema()
        assert "session_id" in tag_schema

    def test_output_schema_packet_schema_has_payload_columns(self, source):
        _, packet_schema = source.output_schema()
        assert "trial" in packet_schema
        assert "response" in packet_schema

    def test_as_table_returns_pyarrow_table(self, source):
        t = source.as_table()
        assert isinstance(t, pa.Table)

    def test_as_table_row_count_matches_source_data(self, source, measurements_table):
        t = source.as_table()
        assert t.num_rows == measurements_table.num_rows

    def test_source_with_explicit_tags_yields_correct_keys(self, connector):
        src = DBTableSource(connector, "measurements", tag_columns=["session_id"])
        for tags, _ in src.iter_packets():
            assert "session_id" in tags
# ===========================================================================
# 7. Deterministic hashing
# ===========================================================================


class TestDeterministicHashing:
    """Hashes must be stable across constructions and sensitive to schema."""

    def test_pipeline_hash_is_deterministic(self, connector):
        src1 = DBTableSource(connector, "measurements")
        src2 = DBTableSource(connector, "measurements")
        assert src1.pipeline_hash() == src2.pipeline_hash()

    def test_content_hash_is_deterministic(self, connector):
        src1 = DBTableSource(connector, "measurements")
        src2 = DBTableSource(connector, "measurements")
        assert src1.content_hash() == src2.content_hash()

    def test_pipeline_hash_is_schema_only_not_source_id(self, connector):
        # pipeline_identity_structure() is (tag_schema, packet_schema) by design —
        # source_id is intentionally excluded so sources with identical schemas
        # share the same pipeline hash and therefore the same pipeline DB table.
        src1 = DBTableSource(connector, "measurements", source_id="a")
        src2 = DBTableSource(connector, "measurements", source_id="b")
        assert src1.pipeline_hash() == src2.pipeline_hash()

    def test_different_tag_columns_yields_different_pipeline_hash(self, connector):
        src1 = DBTableSource(connector, "measurements", tag_columns=["session_id"])
        src2 = DBTableSource(connector, "measurements", tag_columns=["trial"])
        assert src1.pipeline_hash() != src2.pipeline_hash()
Config +# =========================================================================== + + +class TestConfig: + def test_to_config_has_source_type(self, source): + config = source.to_config() + assert config.get("source_type") == "db_table" + + def test_to_config_has_table_name(self, source): + config = source.to_config() + assert config["table_name"] == "measurements" + + def test_to_config_has_tag_columns(self, source): + config = source.to_config() + assert "tag_columns" in config + assert "session_id" in config["tag_columns"] + + def test_to_config_has_connector(self, source): + config = source.to_config() + assert "connector" in config + assert config["connector"]["connector_type"] == "mock" + + def test_to_config_has_source_id(self, source): + config = source.to_config() + assert config["source_id"] == "measurements" + + def test_to_config_has_identity_fields(self, source): + config = source.to_config() + # identity_config() adds content_hash, pipeline_hash, tag_schema, packet_schema + assert "content_hash" in config + assert "pipeline_hash" in config + + def test_from_config_raises_not_implemented(self, source): + config = source.to_config() + with pytest.raises(NotImplementedError): + DBTableSource.from_config(config) + + def test_to_config_explicit_source_id_preserved(self, connector): + src = DBTableSource(connector, "measurements", source_id="custom_id") + config = src.to_config() + assert config["source_id"] == "custom_id" + + def test_to_config_system_tag_columns_preserved(self, connector): + src = DBTableSource( + connector, "measurements", system_tag_columns=["session_id"] + ) + config = src.to_config() + assert "system_tag_columns" in config diff --git a/tests/test_databases/test_connector_arrow_database.py b/tests/test_databases/test_connector_arrow_database.py new file mode 100644 index 0000000..407a0e9 --- /dev/null +++ b/tests/test_databases/test_connector_arrow_database.py @@ -0,0 +1,678 @@ +""" +Comprehensive tests for ColumnInfo, 
DBConnectorProtocol, and ConnectorArrowDatabase. + +Structure mirrors test_in_memory_database.py so the two databases have +symmetric coverage. All tests use a MockDBConnector — no external DB required. + +Test sections: + 1. ColumnInfo dataclass + 2. DBConnectorProtocol structural conformance (via MockDBConnector) + 3. ConnectorArrowDatabase — protocol conformance + 4. Empty-table cases + 5. add_record / get_record_by_id round-trip + 6. add_records / get_all_records + 7. Duplicate handling + 8. get_records_by_ids + 9. get_records_with_column_value +10. Hierarchical record_path + _path_to_table_name +11. Flush behaviour (pending cleared, connector receives data) +12. Config (to_config shape, from_config raises NotImplementedError) +13. Context-manager lifecycle (connector.close is called) +""" +from __future__ import annotations + +import re +from collections.abc import Iterator +from typing import Any + +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +# --------------------------------------------------------------------------- +# Imports under test — all of these will fail until the modules are created +# --------------------------------------------------------------------------- + +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.databases import ConnectorArrowDatabase + + +# --------------------------------------------------------------------------- +# MockDBConnector — in-memory DBConnectorProtocol for tests +# --------------------------------------------------------------------------- + + +class MockDBConnector: + """Pure in-memory implementation of DBConnectorProtocol for testing. + + Holds data as ``dict[str, pa.Table]``. ``iter_batches`` parses the + table name from a bare ``SELECT * FROM "table_name"`` query. 
+ ``upsert_records`` implements insert-or-replace / insert-or-ignore + semantics in memory, mirroring what a real connector would do in SQL. + """ + + def __init__( + self, + tables: dict[str, pa.Table] | None = None, + pk_columns: dict[str, list[str]] | None = None, + ): + self._tables: dict[str, pa.Table] = dict(tables or {}) + self._pk_columns: dict[str, list[str]] = dict(pk_columns or {}) + self.close_called = False + + # ── Schema introspection ────────────────────────────────────────────────── + + def get_table_names(self) -> list[str]: + return list(self._tables.keys()) + + def get_pk_columns(self, table_name: str) -> list[str]: + return list(self._pk_columns.get(table_name, [])) + + def get_column_info(self, table_name: str) -> list[ColumnInfo]: + schema = self._tables[table_name].schema + return [ + ColumnInfo(name=f.name, arrow_type=f.type, nullable=f.nullable) + for f in schema + ] + + # ── Read ────────────────────────────────────────────────────────────────── + + def iter_batches( + self, query: str, params: Any = None, batch_size: int = 1000 + ) -> Iterator[pa.RecordBatch]: + match = re.search(r'FROM\s+"?(\w+)"?', query, re.IGNORECASE) + if not match: + return + table_name = match.group(1) + table = self._tables.get(table_name) + if table is None or table.num_rows == 0: + return + for batch in table.to_batches(max_chunksize=batch_size): + yield batch + + # ── Write ───────────────────────────────────────────────────────────────── + + def create_table_if_not_exists( + self, table_name: str, columns: list[ColumnInfo], pk_column: str + ) -> None: + if table_name not in self._tables: + self._tables[table_name] = pa.table( + {c.name: pa.array([], type=c.arrow_type) for c in columns} + ) + self._pk_columns.setdefault(table_name, [pk_column]) + + def upsert_records( + self, + table_name: str, + records: pa.Table, + id_column: str, + skip_existing: bool = False, + ) -> None: + existing = self._tables.get(table_name) + if existing is None or existing.num_rows 
== 0: + self._tables[table_name] = records + return + new_ids = set(records[id_column].to_pylist()) + if skip_existing: + existing_ids = set(existing[id_column].to_pylist()) + already_there = new_ids & existing_ids + mask = pc.invert( + pc.is_in(records[id_column], pa.array(list(already_there))) + ) + to_add = records.filter(mask) + if to_add.num_rows > 0: + self._tables[table_name] = pa.concat_tables([existing, to_add]) + else: + # INSERT OR REPLACE: remove old rows with matching id, then append + mask = pc.invert(pc.is_in(existing[id_column], pa.array(list(new_ids)))) + kept = existing.filter(mask) + self._tables[table_name] = pa.concat_tables([kept, records]) + + # ── Lifecycle ───────────────────────────────────────────────────────────── + + def close(self) -> None: + self.close_called = True + + def __enter__(self) -> "MockDBConnector": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + # ── Serialization ───────────────────────────────────────────────────────── + + def to_config(self) -> dict[str, Any]: + return {"connector_type": "mock"} + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MockDBConnector": + return cls() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_table(**columns: list) -> pa.Table: + """Build a small PyArrow table from keyword column lists.""" + return pa.table({k: pa.array(v) for k, v in columns.items()}) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def connector(): + return MockDBConnector() + + +@pytest.fixture +def db(connector): + return ConnectorArrowDatabase(connector) + + +# =========================================================================== +# 1. 
# ===========================================================================
# 1. ColumnInfo dataclass
# ===========================================================================


class TestColumnInfo:
    """ColumnInfo must be a frozen, value-comparable dataclass."""

    def test_basic_construction(self):
        col = ColumnInfo(name="my_col", arrow_type=pa.int64())
        assert col.name == "my_col"
        assert col.arrow_type == pa.int64()
        assert col.nullable is True  # default

    def test_nullable_false(self):
        col = ColumnInfo(name="pk", arrow_type=pa.large_string(), nullable=False)
        assert col.nullable is False

    def test_frozen_immutable(self):
        col = ColumnInfo(name="x", arrow_type=pa.float32())
        with pytest.raises((AttributeError, TypeError)):
            col.name = "y"  # type: ignore[misc]

    def test_equality(self):
        a = ColumnInfo(name="v", arrow_type=pa.int32())
        b = ColumnInfo(name="v", arrow_type=pa.int32())
        assert a == b

    def test_inequality_different_type(self):
        a = ColumnInfo(name="v", arrow_type=pa.int32())
        b = ColumnInfo(name="v", arrow_type=pa.int64())
        assert a != b
# ===========================================================================
# 2. DBConnectorProtocol structural conformance
# ===========================================================================


class TestDBConnectorProtocolConformance:
    """MockDBConnector must satisfy DBConnectorProtocol structural checks."""

    def test_mock_satisfies_protocol(self, connector):
        assert isinstance(connector, DBConnectorProtocol)

    def test_has_get_table_names(self, connector):
        assert callable(connector.get_table_names)

    def test_has_get_pk_columns(self, connector):
        assert callable(connector.get_pk_columns)

    def test_has_get_column_info(self, connector):
        assert callable(connector.get_column_info)

    def test_has_iter_batches(self, connector):
        assert callable(connector.iter_batches)

    def test_has_create_table_if_not_exists(self, connector):
        assert callable(connector.create_table_if_not_exists)

    def test_has_upsert_records(self, connector):
        assert callable(connector.upsert_records)

    def test_has_close(self, connector):
        assert callable(connector.close)

    def test_has_to_config(self, connector):
        assert callable(connector.to_config)

    def test_to_config_has_connector_type_key(self, connector):
        config = connector.to_config()
        assert "connector_type" in config

    def test_context_manager_calls_close(self, connector):
        assert not connector.close_called
        with connector:
            pass
        assert connector.close_called
# ===========================================================================
# 3. ConnectorArrowDatabase — protocol conformance
# ===========================================================================


class TestConnectorArrowDatabaseConformance:
    """ConnectorArrowDatabase must expose the full ArrowDatabaseProtocol surface."""

    def test_satisfies_arrow_database_protocol(self, db):
        assert isinstance(db, ArrowDatabaseProtocol)

    def test_has_add_record(self, db):
        assert callable(db.add_record)

    def test_has_add_records(self, db):
        assert callable(db.add_records)

    def test_has_get_record_by_id(self, db):
        assert callable(db.get_record_by_id)

    def test_has_get_all_records(self, db):
        assert callable(db.get_all_records)

    def test_has_get_records_by_ids(self, db):
        assert callable(db.get_records_by_ids)

    def test_has_get_records_with_column_value(self, db):
        assert callable(db.get_records_with_column_value)

    def test_has_flush(self, db):
        assert callable(db.flush)

    def test_has_to_config(self, db):
        assert callable(db.to_config)

    def test_has_from_config(self, db):
        assert callable(db.from_config)


# ===========================================================================
# 4. Empty-table cases
# ===========================================================================


class TestEmptyTable:
    """All read paths return None when nothing has been stored."""

    PATH = ("source", "v1")

    def test_get_record_by_id_returns_none_when_empty(self, db):
        assert db.get_record_by_id(self.PATH, "id-1", flush=True) is None

    def test_get_all_records_returns_none_when_empty(self, db):
        assert db.get_all_records(self.PATH) is None

    def test_get_records_by_ids_returns_none_when_empty(self, db):
        assert db.get_records_by_ids(self.PATH, ["id-1"], flush=True) is None

    def test_get_records_with_column_value_returns_none_when_empty(self, db):
        assert (
            db.get_records_with_column_value(self.PATH, {"value": 1}, flush=True)
            is None
        )
# ===========================================================================
# 5. add_record / get_record_by_id round-trip
# ===========================================================================


class TestAddRecordRoundTrip:
    """Single-record writes must be readable from pending and after flush."""

    PATH = ("source", "v1")

    def test_added_record_retrievable_from_pending(self, db):
        record = make_table(value=[42])
        db.add_record(self.PATH, "id-1", record)
        result = db.get_record_by_id(self.PATH, "id-1")
        assert result is not None
        assert result.column("value").to_pylist() == [42]

    def test_added_record_retrievable_after_flush(self, db):
        record = make_table(value=[99])
        db.add_record(self.PATH, "id-2", record)
        db.flush()
        result = db.get_record_by_id(self.PATH, "id-2", flush=True)
        assert result is not None
        assert result.column("value").to_pylist() == [99]

    def test_record_id_column_not_in_result_by_default(self, db):
        record = make_table(value=[1])
        db.add_record(self.PATH, "id-3", record)
        result = db.get_record_by_id(self.PATH, "id-3")
        assert result is not None
        assert ConnectorArrowDatabase.RECORD_ID_COLUMN not in result.column_names

    def test_record_id_column_exposed_when_requested(self, db):
        record = make_table(value=[1])
        db.add_record(self.PATH, "id-4", record)
        db.flush()
        result = db.get_record_by_id(
            self.PATH, "id-4", record_id_column="my_id", flush=True
        )
        assert result is not None
        assert "my_id" in result.column_names
        assert result.column("my_id").to_pylist() == ["id-4"]

    def test_unknown_record_returns_none(self, db):
        record = make_table(value=[1])
        db.add_record(self.PATH, "id-5", record)
        db.flush()
        assert db.get_record_by_id(self.PATH, "nonexistent", flush=True) is None

    def test_multi_row_record_deduplicates_to_last_row(self, db):
        # add_record stamps ALL rows with the same __record_id value, so
        # within-batch deduplication (keep-last) leaves a single row.
        # This mirrors InMemoryArrowDatabase behaviour by design.
        record = make_table(x=[1, 2, 3])
        db.add_record(self.PATH, "multi-row", record)
        db.flush()
        result = db.get_record_by_id(self.PATH, "multi-row", flush=True)
        assert result is not None
        assert result.num_rows == 1
        assert result.column("x").to_pylist() == [3]  # last row kept
+ record = make_table(x=[1, 2, 3]) + db.add_record(self.PATH, "multi-row", record) + db.flush() + result = db.get_record_by_id(self.PATH, "multi-row", flush=True) + assert result is not None + assert result.num_rows == 1 + assert result.column("x").to_pylist() == [3] # last row kept + + +# =========================================================================== +# 6. add_records / get_all_records +# =========================================================================== + + +class TestAddRecordsRoundTrip: + PATH = ("multi", "v1") + + def test_add_records_bulk_and_retrieve_all(self, db): + records = make_table(__record_id=["a", "b", "c"], value=[10, 20, 30]) + db.add_records(self.PATH, records, record_id_column="__record_id") + db.flush() + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 3 + + def test_get_all_records_includes_pending(self, db): + records = make_table(__record_id=["x", "y"], value=[1, 2]) + db.add_records(self.PATH, records, record_id_column="__record_id") + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 2 + + def test_first_column_used_as_record_id_by_default(self, db): + records = make_table(id=["r1", "r2"], score=[5, 6]) + db.add_records(self.PATH, records) + db.flush() + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 2 + + def test_pending_and_committed_combined(self, db): + """Records added before and after flush should both appear in get_all_records.""" + db.add_records( + self.PATH, + make_table(__record_id=["a"], v=[1]), + record_id_column="__record_id", + flush=True, + ) + db.add_records( + self.PATH, + make_table(__record_id=["b"], v=[2]), + record_id_column="__record_id", + ) + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 2 + + def test_empty_records_table_is_ignored(self, db): + empty = make_table(__record_id=[], v=[]) + 
db.add_records(self.PATH, empty, record_id_column="__record_id") + assert db.get_all_records(self.PATH) is None + + +# =========================================================================== +# 7. Duplicate handling +# =========================================================================== + + +class TestDuplicateHandling: + PATH = ("dup", "v1") + + def test_skip_duplicates_true_does_not_raise(self, db): + db.add_record(self.PATH, "dup-id", make_table(value=[1])) + db.flush() + # same id again with skip_duplicates=True — should silently skip + db.add_record(self.PATH, "dup-id", make_table(value=[2]), skip_duplicates=True) + + def test_skip_duplicates_preserves_original_value(self, db): + db.add_record(self.PATH, "dup-id", make_table(value=[1]), flush=True) + db.add_record( + self.PATH, "dup-id", make_table(value=[99]), skip_duplicates=True, flush=True + ) + result = db.get_record_by_id(self.PATH, "dup-id", flush=True) + assert result is not None + assert result.column("value").to_pylist() == [1] # original preserved + + def test_skip_duplicates_false_raises_on_pending_duplicate(self, db): + db.add_record(self.PATH, "dup-id2", make_table(value=[1])) + with pytest.raises(ValueError): + db.add_records( + self.PATH, + make_table(__record_id=["dup-id2"], value=[99]), + record_id_column="__record_id", + skip_duplicates=False, + ) + + def test_within_batch_deduplication_keeps_last(self, db): + records = make_table(__record_id=["same", "same"], value=[1, 2]) + db.add_records(self.PATH, records, record_id_column="__record_id") + db.flush() + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 1 + assert result.column("value").to_pylist() == [2] + + +# =========================================================================== +# 8. 
# ===========================================================================
# 8. get_records_by_ids
# ===========================================================================


class TestGetRecordsByIds:
    """Subset retrieval by explicit id lists."""

    PATH = ("byids", "v1")

    @pytest.fixture(autouse=True)
    def populate(self, db):
        records = make_table(__record_id=["a", "b", "c"], value=[10, 20, 30])
        db.add_records(self.PATH, records, record_id_column="__record_id")
        db.flush()

    def test_retrieves_subset(self, db):
        result = db.get_records_by_ids(self.PATH, ["a", "c"], flush=True)
        assert result is not None
        assert result.num_rows == 2

    def test_returns_none_for_missing_ids(self, db):
        result = db.get_records_by_ids(self.PATH, ["z"], flush=True)
        assert result is None

    def test_empty_id_list_returns_none(self, db):
        assert db.get_records_by_ids(self.PATH, [], flush=True) is None

    def test_retrieves_single_id(self, db):
        result = db.get_records_by_ids(self.PATH, ["b"], flush=True)
        assert result is not None
        assert result.num_rows == 1
# ===========================================================================
# 9. get_records_with_column_value
# ===========================================================================


class TestGetRecordsWithColumnValue:
    """Filtering by column value, with mapping and tuple-collection inputs."""

    PATH = ("colval", "v1")

    @pytest.fixture(autouse=True)
    def populate(self, db):
        records = make_table(
            __record_id=["p", "q", "r"], category=["A", "B", "A"]
        )
        db.add_records(self.PATH, records, record_id_column="__record_id")
        db.flush()

    def test_filters_by_column_value(self, db):
        result = db.get_records_with_column_value(
            self.PATH, {"category": "A"}, flush=True
        )
        assert result is not None
        assert result.num_rows == 2

    def test_no_match_returns_none(self, db):
        result = db.get_records_with_column_value(
            self.PATH, {"category": "Z"}, flush=True
        )
        assert result is None

    def test_accepts_mapping_and_collection_of_tuples(self, db):
        result_mapping = db.get_records_with_column_value(
            self.PATH, {"category": "B"}, flush=True
        )
        result_tuples = db.get_records_with_column_value(
            self.PATH, [("category", "B")], flush=True
        )
        assert result_mapping is not None and result_tuples is not None
        assert result_mapping.num_rows == result_tuples.num_rows
# ===========================================================================
# 10. Hierarchical record_path + _path_to_table_name
# ===========================================================================


class TestHierarchicalPath:
    """Multi-component record paths must map to independent tables."""

    def test_deep_path_stores_and_retrieves(self, db):
        path = ("org", "project", "dataset", "v1")
        db.add_record(path, "deep-id", make_table(x=[7]))
        db.flush()
        result = db.get_record_by_id(path, "deep-id", flush=True)
        assert result is not None
        assert result.column("x").to_pylist() == [7]

    def test_different_paths_are_independent(self, db):
        path_a = ("ns", "a")
        path_b = ("ns", "b")
        db.add_record(path_a, "id-1", make_table(v=[1]))
        db.add_record(path_b, "id-1", make_table(v=[2]))
        db.flush()
        result_a = db.get_record_by_id(path_a, "id-1", flush=True)
        result_b = db.get_record_by_id(path_b, "id-1", flush=True)
        assert result_a.column("v").to_pylist() == [1]
        assert result_b.column("v").to_pylist() == [2]

    def test_invalid_empty_path_raises(self, db):
        with pytest.raises(ValueError):
            db.add_record((), "id-1", make_table(v=[1]))

    def test_path_exceeding_max_depth_raises(self, db):
        path = tuple(f"part{i}" for i in range(db.max_hierarchy_depth + 1))
        with pytest.raises(ValueError, match="exceeds maximum"):
            db.add_record(path, "id-1", make_table(v=[1]))


class TestPathToTableName:
    """Table-name derivation from record paths must be safe and deterministic."""

    def test_simple_single_component(self, db):
        assert db._path_to_table_name(("results",)) == "results"

    def test_two_components_joined_with_double_underscore(self, db):
        assert db._path_to_table_name(("results", "my_fn")) == "results__my_fn"

    def test_special_chars_replaced_with_underscore(self, db):
        name = db._path_to_table_name(("a:b", "c/d"))
        assert "__" in name
        assert ":" not in name
        assert "/" not in name

    def test_digit_prefix_gets_t_prefix(self, db):
        name = db._path_to_table_name(("1abc",))
        assert name.startswith("t_")

    def test_digit_prefix_only_when_first_char_is_digit(self, db):
        name = db._path_to_table_name(("abc1",))
        assert not name.startswith("t_")

    def test_same_path_always_yields_same_table_name(self, db):
        path = ("foo", "bar", "baz")
        assert db._path_to_table_name(path) == db._path_to_table_name(path)
test_same_path_always_yields_same_table_name(self, db): + path = ("foo", "bar", "baz") + assert db._path_to_table_name(path) == db._path_to_table_name(path) + + +# =========================================================================== +# 11. Flush behaviour +# =========================================================================== + + +class TestFlushBehaviour: + PATH = ("flush", "v1") + + def test_flush_writes_pending_to_connector(self, db, connector): + db.add_record(self.PATH, "f1", make_table(v=[1])) + db.add_record(self.PATH, "f2", make_table(v=[2])) + # pending key exists before flush + record_key = db._get_record_key(self.PATH) + assert record_key in db._pending_batches + db.flush() + # pending cleared after flush + assert record_key not in db._pending_batches + # connector now has the table + table_name = db._path_to_table_name(self.PATH) + assert table_name in connector.get_table_names() + + def test_flush_inline_via_flush_kwarg(self, db, connector): + db.add_record(self.PATH, "x", make_table(v=[5]), flush=True) + table_name = db._path_to_table_name(self.PATH) + assert table_name in connector.get_table_names() + + def test_multiple_flushes_accumulate_records(self, db): + db.add_record(self.PATH, "m1", make_table(v=[10])) + db.flush() + db.add_record(self.PATH, "m2", make_table(v=[20])) + db.flush() + result = db.get_all_records(self.PATH) + assert result is not None + assert result.num_rows == 2 + + def test_second_flush_on_existing_table_upserts(self, db): + """Flushing the same path twice should not duplicate rows.""" + db.add_record(self.PATH, "u1", make_table(v=[1]), flush=True) + db.add_record(self.PATH, "u1", make_table(v=[99]), skip_duplicates=True, flush=True) + result = db.get_all_records(self.PATH) + assert result is not None + # skip_existing=True means original is preserved, row count stays 1 + assert result.num_rows == 1 + + def test_noop_flush_does_not_raise(self, db): + """Flushing with nothing pending should be a no-op.""" + 
db.flush() # no error + + +# =========================================================================== +# 12. Config +# =========================================================================== + + +class TestConfig: + def test_to_config_has_type_key(self, db): + config = db.to_config() + assert config.get("type") == "connector_arrow_database" + + def test_to_config_includes_connector_config(self, db): + config = db.to_config() + assert "connector" in config + assert config["connector"]["connector_type"] == "mock" + + def test_to_config_includes_max_hierarchy_depth(self, db): + config = db.to_config() + assert "max_hierarchy_depth" in config + + def test_from_config_raises_not_implemented(self, db): + config = db.to_config() + with pytest.raises(NotImplementedError): + ConnectorArrowDatabase.from_config(config) From ecf08c127a9550e5689895dd154bfcba0fb8d963 Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 09:20:28 +0000 Subject: [PATCH 5/8] fix: address Copilot review feedback on ConnectorArrowDatabase --- .../databases/connector_arrow_database.py | 56 +++++++++++++++---- .../test_connector_arrow_database.py | 36 ++++++++++++ 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/src/orcapod/databases/connector_arrow_database.py b/src/orcapod/databases/connector_arrow_database.py index 7d999cf..33e4e3d 100644 --- a/src/orcapod/databases/connector_arrow_database.py +++ b/src/orcapod/databases/connector_arrow_database.py @@ -67,6 +67,10 @@ def __init__( self.max_hierarchy_depth = max_hierarchy_depth self._pending_batches: dict[str, pa.Table] = {} self._pending_record_ids: dict[str, set[str]] = defaultdict(set) + # Per-batch flag: True when the batch was added with skip_duplicates=True, + # so flush() can pass skip_existing=True to the connector and let it use + # native INSERT-OR-IGNORE semantics rather than Python-side prefiltering. 
+ self._pending_skip_existing: dict[str, bool] = {} # ── Path helpers ────────────────────────────────────────────────────────── @@ -99,6 +103,14 @@ def _validate_record_path(self, record_path: tuple[str, ...]) -> None: raise ValueError( f"record_path component {i} is invalid: {repr(component)}" ) + # "/" is the separator used by _get_record_key; "\0" is a common + # string-boundary sentinel. Both would corrupt key round-tripping + # in flush() where record_path is reconstructed via split("/"). + if "/" in component or "\0" in component: + raise ValueError( + f"record_path component {repr(component)} contains an " + "invalid character ('/' or '\\0')" + ) # ── Record-ID column helpers ────────────────────────────────────────────── @@ -216,24 +228,23 @@ def add_records( input_ids = set(cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist())) if skip_duplicates: - committed = self._get_committed_table(record_path) - committed_ids: set[str] = set() - if committed is not None: - committed_ids = set( - cast(list[str], committed[self.RECORD_ID_COLUMN].to_pylist()) - ) - all_existing = (input_ids & self._pending_record_ids[record_key]) | ( - input_ids & committed_ids - ) - if all_existing: + # Only filter records that conflict with the in-flight pending batch. + # Committed duplicates are handled at flush time via + # upsert_records(skip_existing=True), which lets the connector use + # native INSERT-OR-IGNORE semantics — no full-table read needed here. + pending_conflicts = input_ids & self._pending_record_ids[record_key] + if pending_conflicts: mask = pc.invert( pc.is_in( - records[self.RECORD_ID_COLUMN], pa.array(list(all_existing)) + records[self.RECORD_ID_COLUMN], + pa.array(list(pending_conflicts)), ) ) records = records.filter(mask) if records.num_rows == 0: return + # Mark this pending slot so flush() uses skip_existing=True. 
+ self._pending_skip_existing[record_key] = True else: conflicts = input_ids & self._pending_record_ids[record_key] if conflicts: @@ -266,8 +277,29 @@ def flush(self) -> None: table_name = self._path_to_table_name(record_path) pending = self._pending_batches.pop(record_key) self._pending_record_ids.pop(record_key, None) + skip_existing = self._pending_skip_existing.pop(record_key, False) columns = _arrow_schema_to_column_infos(pending.schema) + + # Schema validation: if the table already exists, confirm the column + # names and Arrow types match before writing. Schema evolution is + # intentionally out of scope; a clear ValueError is preferable to a + # cryptic DB-level error or a silent partial write. + existing_table_names = self._connector.get_table_names() + if table_name in existing_table_names: + existing_cols = { + c.name: c.arrow_type + for c in self._connector.get_column_info(table_name) + } + pending_cols = {c.name: c.arrow_type for c in columns} + if existing_cols != pending_cols: + raise ValueError( + f"Schema mismatch for table {table_name!r}: " + f"existing columns {sorted(existing_cols)} differ from " + f"pending columns {sorted(pending_cols)}. " + "Schema evolution is not supported." 
+ ) + self._connector.create_table_if_not_exists( table_name, columns, pk_column=self.RECORD_ID_COLUMN ) @@ -275,7 +307,7 @@ def flush(self) -> None: table_name, pending, id_column=self.RECORD_ID_COLUMN, - skip_existing=False, + skip_existing=skip_existing, ) # ── Read methods ────────────────────────────────────────────────────────── diff --git a/tests/test_databases/test_connector_arrow_database.py b/tests/test_databases/test_connector_arrow_database.py index 407a0e9..50ceec9 100644 --- a/tests/test_databases/test_connector_arrow_database.py +++ b/tests/test_databases/test_connector_arrow_database.py @@ -577,6 +577,16 @@ def test_path_exceeding_max_depth_raises(self, db): with pytest.raises(ValueError, match="exceeds maximum"): db.add_record(path, "id-1", make_table(v=[1])) + def test_path_component_with_slash_raises(self, db): + # "/" is the _get_record_key separator; allowing it would corrupt + # flush()'s record_path reconstruction via split("/"). + with pytest.raises(ValueError, match="invalid character"): + db.add_record(("bad/path",), "id-1", make_table(v=[1])) + + def test_path_component_with_null_byte_raises(self, db): + with pytest.raises(ValueError, match="invalid character"): + db.add_record(("bad\x00path",), "id-1", make_table(v=[1])) + class TestPathToTableName: def test_simple_single_component(self, db): @@ -652,6 +662,32 @@ def test_noop_flush_does_not_raise(self, db): """Flushing with nothing pending should be a no-op.""" db.flush() # no error + def test_flush_with_skip_duplicates_passes_skip_existing_to_connector( + self, db, connector + ): + """skip_duplicates=True must translate to skip_existing=True at flush, + so connectors can use native INSERT-OR-IGNORE without a Python-side + full-table read.""" + db.add_record(("t",), "a", make_table(v=[1]), flush=True) + # Second add with skip_duplicates=True should not overwrite v=1 + db.add_record(("t",), "a", make_table(v=[99]), skip_duplicates=True, flush=True) + result = db.get_record_by_id(("t",), 
"a", flush=True) + assert result is not None + assert result["v"][0].as_py() == 1 # original preserved via skip_existing=True + + def test_flush_schema_mismatch_raises_value_error(self, db): + """flush() must raise ValueError when the pending schema differs from + the table already in the connector — before any data is written.""" + db.add_record(("t",), "a", make_table(v=[1]), flush=True) + # Now try to flush a batch with a different column name + db.add_records( + ("t",), + pa.table({"__record_id": pa.array(["b"]), "x": pa.array([2])}), + record_id_column="__record_id", + ) + with pytest.raises(ValueError, match="Schema mismatch"): + db.flush() + # =========================================================================== # 12. Config From c23eaa4f5019b78f80e09a6f9f6d10508cbb6ac7 Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 10:19:01 +0000 Subject: [PATCH 6/8] refactor: address eywalker review comments on code style - Move ColumnInfo to orcapod.types (consistent with other concrete type definitions); db_connector_protocol re-exports it for backward compat - Add from __future__ import annotations to database_protocols.py and drop all quoted type hints throughout new files - Move re-export import to top of database_protocols.py, eliminating the noqa: E402 suppression - Replace **kwargs in DBTableSource.__init__ with explicit label, data_context, and config parameters matching RootSource signature - Update test imports to use ColumnInfo from orcapod.types --- src/orcapod/core/sources/db_table_source.py | 22 +++++++++---- .../databases/connector_arrow_database.py | 30 +++++++++--------- src/orcapod/protocols/database_protocols.py | 26 ++++++++-------- .../protocols/db_connector_protocol.py | 31 ++++++------------- src/orcapod/types.py | 20 ++++++++++++ .../test_core/sources/test_db_table_source.py | 3 +- .../test_connector_arrow_database.py | 3 +- 7 files changed, 77 insertions(+), 
58 deletions(-) diff --git a/src/orcapod/core/sources/db_table_source.py b/src/orcapod/core/sources/db_table_source.py index 873f231..9f0459f 100644 --- a/src/orcapod/core/sources/db_table_source.py +++ b/src/orcapod/core/sources/db_table_source.py @@ -21,6 +21,8 @@ if TYPE_CHECKING: import pyarrow as pa + from orcapod import contexts + from orcapod.config import Config from orcapod.protocols.db_connector_protocol import DBConnectorProtocol else: pa = LazyModule("pyarrow") @@ -47,8 +49,9 @@ class DBTableSource(RootSource): strings. If ``None``, row indices are used. source_id: Canonical source name for the registry and provenance tokens. Defaults to ``table_name``. - **kwargs: Forwarded to ``RootSource`` (``label``, ``data_context``, - ``config``). + label: Human-readable label for this source node. + data_context: Data context governing type conversion and hashing. + config: Orcapod configuration (controls hash character counts, etc.). Raises: ValueError: If the table is not found, has no PK columns and none are @@ -57,17 +60,24 @@ class DBTableSource(RootSource): def __init__( self, - connector: "DBConnectorProtocol", + connector: DBConnectorProtocol, table_name: str, tag_columns: Collection[str] | None = None, system_tag_columns: Collection[str] = (), record_id_column: str | None = None, source_id: str | None = None, - **kwargs: Any, + label: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, ) -> None: if source_id is None: source_id = table_name - super().__init__(source_id=source_id, **kwargs) + super().__init__( + source_id=source_id, + label=label, + data_context=data_context, + config=config, + ) self._connector = connector self._table_name = table_name @@ -125,7 +135,7 @@ def to_config(self) -> dict[str, Any]: } @classmethod - def from_config(cls, config: dict[str, Any]) -> "DBTableSource": + def from_config(cls, config: dict[str, Any]) -> DBTableSource: """Not yet implemented — requires a 
connector factory registry. Raises: diff --git a/src/orcapod/databases/connector_arrow_database.py b/src/orcapod/databases/connector_arrow_database.py index 33e4e3d..0a04492 100644 --- a/src/orcapod/databases/connector_arrow_database.py +++ b/src/orcapod/databases/connector_arrow_database.py @@ -33,7 +33,7 @@ pc = LazyModule("pyarrow.compute") -def _arrow_schema_to_column_infos(schema: "pa.Schema") -> list[ColumnInfo]: +def _arrow_schema_to_column_infos(schema: pa.Schema) -> list[ColumnInfo]: """Convert a PyArrow schema to a list of ColumnInfo.""" return [ ColumnInfo(name=field.name, arrow_type=field.type, nullable=field.nullable) @@ -115,8 +115,8 @@ def _validate_record_path(self, record_path: tuple[str, ...]) -> None: # ── Record-ID column helpers ────────────────────────────────────────────── def _ensure_record_id_column( - self, arrow_data: "pa.Table", record_id: str - ) -> "pa.Table": + self, arrow_data: pa.Table, record_id: str + ) -> pa.Table: if self.RECORD_ID_COLUMN not in arrow_data.column_names: key_array = pa.array( [record_id] * len(arrow_data), type=pa.large_string() @@ -124,14 +124,14 @@ def _ensure_record_id_column( arrow_data = arrow_data.add_column(0, self.RECORD_ID_COLUMN, key_array) return arrow_data - def _remove_record_id_column(self, arrow_data: "pa.Table") -> "pa.Table": + def _remove_record_id_column(self, arrow_data: pa.Table) -> pa.Table: if self.RECORD_ID_COLUMN in arrow_data.column_names: arrow_data = arrow_data.drop([self.RECORD_ID_COLUMN]) return arrow_data def _handle_record_id_column( - self, arrow_data: "pa.Table", record_id_column: str | None - ) -> "pa.Table": + self, arrow_data: pa.Table, record_id_column: str | None + ) -> pa.Table: if not record_id_column: return self._remove_record_id_column(arrow_data) if self.RECORD_ID_COLUMN in arrow_data.column_names: @@ -146,7 +146,7 @@ def _handle_record_id_column( # ── Deduplication ───────────────────────────────────────────────────────── - def _deduplicate_within_table(self, table: 
"pa.Table") -> "pa.Table": + def _deduplicate_within_table(self, table: pa.Table) -> pa.Table: """Keep the last occurrence of each record ID within a single table.""" if table.num_rows <= 1: return table @@ -163,7 +163,7 @@ def _deduplicate_within_table(self, table: "pa.Table") -> "pa.Table": def _get_committed_table( self, record_path: tuple[str, ...] - ) -> "pa.Table | None": + ) -> pa.Table | None: """Fetch all committed records for a path from the connector.""" table_name = self._path_to_table_name(record_path) if table_name not in self._connector.get_table_names(): @@ -181,7 +181,7 @@ def add_record( self, record_path: tuple[str, ...], record_id: str, - record: "pa.Table", + record: pa.Table, skip_duplicates: bool = False, flush: bool = False, ) -> None: @@ -198,7 +198,7 @@ def add_record( def add_records( self, record_path: tuple[str, ...], - records: "pa.Table", + records: pa.Table, record_id_column: str | None = None, skip_duplicates: bool = False, flush: bool = False, @@ -318,7 +318,7 @@ def get_record_by_id( record_id: str, record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": + ) -> pa.Table | None: if flush: self.flush() record_key = self._get_record_key(record_path) @@ -343,7 +343,7 @@ def get_all_records( self, record_path: tuple[str, ...], record_id_column: str | None = None, - ) -> "pa.Table | None": + ) -> pa.Table | None: record_key = self._get_record_key(record_path) parts: list[pa.Table] = [] @@ -365,7 +365,7 @@ def get_records_by_ids( record_ids: Collection[str], record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": + ) -> pa.Table | None: if flush: self.flush() ids_list = list(record_ids) @@ -389,7 +389,7 @@ def get_records_with_column_value( column_values: Collection[tuple[str, Any]] | Mapping[str, Any], record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": + ) -> pa.Table | None: if flush: self.flush() all_records = self.get_all_records( @@ -424,7 
+424,7 @@ def to_config(self) -> dict[str, Any]: } @classmethod - def from_config(cls, config: dict[str, Any]) -> "ConnectorArrowDatabase": + def from_config(cls, config: dict[str, Any]) -> ConnectorArrowDatabase: """Reconstruct a ConnectorArrowDatabase from config. Raises: diff --git a/src/orcapod/protocols/database_protocols.py b/src/orcapod/protocols/database_protocols.py index 4033ad5..72e3dc7 100644 --- a/src/orcapod/protocols/database_protocols.py +++ b/src/orcapod/protocols/database_protocols.py @@ -1,5 +1,9 @@ -from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable +from __future__ import annotations + from collections.abc import Collection, Mapping +from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable + +from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol if TYPE_CHECKING: import pyarrow as pa @@ -11,7 +15,7 @@ def add_record( self, record_path: tuple[str, ...], record_id: str, - record: "pa.Table", + record: pa.Table, skip_duplicates: bool = False, flush: bool = False, ) -> None: ... @@ -19,7 +23,7 @@ def add_record( def add_records( self, record_path: tuple[str, ...], - records: "pa.Table", + records: pa.Table, record_id_column: str | None = None, skip_duplicates: bool = False, flush: bool = False, @@ -31,13 +35,13 @@ def get_record_by_id( record_id: str, record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": ... + ) -> pa.Table | None: ... def get_all_records( self, record_path: tuple[str, ...], record_id_column: str | None = None, - ) -> "pa.Table | None": + ) -> pa.Table | None: """Retrieve all records for a given path as a stream.""" ... @@ -47,7 +51,7 @@ def get_records_by_ids( record_ids: Collection[str], record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": ... + ) -> pa.Table | None: ... 
def get_records_with_column_value( self, @@ -55,13 +59,13 @@ def get_records_with_column_value( column_values: Collection[tuple[str, Any]] | Mapping[str, Any], record_id_column: str | None = None, flush: bool = False, - ) -> "pa.Table | None": ... + ) -> pa.Table | None: ... def flush(self) -> None: """Flush any buffered writes to the underlying storage.""" ... - def to_config(self) -> "dict[str, Any]": + def to_config(self) -> dict[str, Any]: """Serialize database configuration to a JSON-compatible dict. The returned dict must include a ``"type"`` key identifying the @@ -70,7 +74,7 @@ def to_config(self) -> "dict[str, Any]": ... @classmethod - def from_config(cls, config: "dict[str, Any]") -> "ArrowDatabaseProtocol": + def from_config(cls, config: dict[str, Any]) -> ArrowDatabaseProtocol: """Reconstruct a database instance from a config dict.""" ... @@ -104,10 +108,6 @@ class ArrowDatabaseWithMetadataProtocol( pass -# Re-export connector abstractions so callers can import everything DB-related -# from one place: ``from orcapod.protocols.database_protocols import ...`` -from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol # noqa: E402 - __all__ = [ "ArrowDatabaseProtocol", "ArrowDatabaseWithMetadataProtocol", diff --git a/src/orcapod/protocols/db_connector_protocol.py b/src/orcapod/protocols/db_connector_protocol.py index 7b0dc2a..e4ec128 100644 --- a/src/orcapod/protocols/db_connector_protocol.py +++ b/src/orcapod/protocols/db_connector_protocol.py @@ -12,29 +12,16 @@ from __future__ import annotations from collections.abc import Iterator -from dataclasses import dataclass from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable +from orcapod.types import ColumnInfo + if TYPE_CHECKING: import pyarrow as pa - -@dataclass(frozen=True) -class ColumnInfo: - """Metadata for a single database column with its Arrow-mapped type. - - Type mapping (DB-native → Arrow) is the connector's responsibility. 
- Consumers of ``DBConnectorProtocol`` always see Arrow types. - - Args: - name: Column name. - arrow_type: Arrow data type (already mapped from the DB-native type). - nullable: Whether the column accepts NULL values. - """ - - name: str - arrow_type: "pa.DataType" - nullable: bool = True +# Re-export so existing ``from orcapod.protocols.db_connector_protocol import ColumnInfo`` +# imports continue to work while the canonical definition lives in orcapod.types. +__all__ = ["ColumnInfo", "DBConnectorProtocol"] @runtime_checkable @@ -82,7 +69,7 @@ def iter_batches( query: str, params: Any = None, batch_size: int = 1000, - ) -> Iterator["pa.RecordBatch"]: + ) -> Iterator[pa.RecordBatch]: """Execute a query and yield results as Arrow RecordBatches. Args: @@ -114,7 +101,7 @@ def create_table_if_not_exists( def upsert_records( self, table_name: str, - records: "pa.Table", + records: pa.Table, id_column: str, skip_existing: bool = False, ) -> None: @@ -136,7 +123,7 @@ def close(self) -> None: """Release the database connection and any associated resources.""" ... - def __enter__(self) -> "DBConnectorProtocol": + def __enter__(self) -> DBConnectorProtocol: ... def __exit__(self, *args: Any) -> None: @@ -155,6 +142,6 @@ def to_config(self) -> dict[str, Any]: ... @classmethod - def from_config(cls, config: dict[str, Any]) -> "DBConnectorProtocol": + def from_config(cls, config: dict[str, Any]) -> DBConnectorProtocol: """Reconstruct a connector instance from a config dict.""" ... diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 3f8938d..d108273 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -432,6 +432,26 @@ def handle_config( return column_config +@dataclass(frozen=True) +class ColumnInfo: + """Metadata for a single relational database column with its Arrow-mapped type. + + ``ColumnInfo`` is produced by ``DBConnectorProtocol.get_column_info()`` and + consumed by ``ConnectorArrowDatabase`` and ``DBTableSource``. 
Type mapping + (DB-native → Arrow) is always the connector's responsibility; callers always + receive Arrow types. + + Args: + name: Column name. + arrow_type: Arrow data type (already mapped from the DB-native type). + nullable: Whether the column accepts NULL values. Defaults to ``True``. + """ + + name: str + arrow_type: pa.DataType + nullable: bool = True + + @dataclass(frozen=True, slots=True) class ContentHash: """Content-addressable hash pairing a hashing method with a raw digest. diff --git a/tests/test_core/sources/test_db_table_source.py b/tests/test_core/sources/test_db_table_source.py index 0816dc9..fabd434 100644 --- a/tests/test_core/sources/test_db_table_source.py +++ b/tests/test_core/sources/test_db_table_source.py @@ -22,7 +22,8 @@ from orcapod.core.sources import DBTableSource from orcapod.protocols.core_protocols import SourceProtocol, StreamProtocol -from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.db_connector_protocol import DBConnectorProtocol +from orcapod.types import ColumnInfo from orcapod.protocols.hashing_protocols import PipelineElementProtocol diff --git a/tests/test_databases/test_connector_arrow_database.py b/tests/test_databases/test_connector_arrow_database.py index 50ceec9..4b8bd69 100644 --- a/tests/test_databases/test_connector_arrow_database.py +++ b/tests/test_databases/test_connector_arrow_database.py @@ -33,7 +33,8 @@ # Imports under test — all of these will fail until the modules are created # --------------------------------------------------------------------------- -from orcapod.protocols.db_connector_protocol import ColumnInfo, DBConnectorProtocol +from orcapod.protocols.db_connector_protocol import DBConnectorProtocol +from orcapod.types import ColumnInfo from orcapod.protocols.database_protocols import ArrowDatabaseProtocol from orcapod.databases import ConnectorArrowDatabase From d3195e90ee7a7a61b392d2b99e3ea0eb8ef2107c Mon Sep 17 00:00:00 2001 From: 
"agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 10:39:03 +0000 Subject: [PATCH 7/8] refactor: make pyarrow a lazy import in orcapod/types.py Replace the unconditional `import pyarrow as pa` with the LazyModule pattern already used across the codebase (`LazyModule("pyarrow")` at runtime, real import under `TYPE_CHECKING`). Since `from __future__ import annotations` is in place, the `pa.DataType` annotation on `ColumnInfo` is a no-op string at import time. The two module-level dicts (`_PYTHON_TO_ARROW`, `_ARROW_TO_PYTHON`) called `pa.*()` at module load time, which would defeat laziness. They are replaced by `@functools.cache` functions (`_python_to_arrow()` / `_arrow_to_python()`) that build the mappings on first access. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/orcapod/types.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/orcapod/types.py b/src/orcapod/types.py index d108273..96f4e39 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -11,6 +11,7 @@ from __future__ import annotations +import functools import logging import os import uuid @@ -21,23 +22,33 @@ from typing import TYPE_CHECKING, Any, Self, TypeAlias if TYPE_CHECKING: + import pyarrow as pa + from orcapod.protocols.core_protocols import PacketFunctionExecutorProtocol +else: + from orcapod.utils.lazy_module import LazyModule -import pyarrow as pa + pa = LazyModule("pyarrow") logger = logging.getLogger(__name__) -# Mapping from Python types to Arrow types. -_PYTHON_TO_ARROW: dict[type, pa.DataType] = { - int: pa.int64(), - float: pa.float64(), - str: pa.string(), - bool: pa.bool_(), - bytes: pa.binary(), -} - -# Reverse mapping from Arrow types back to Python types. 
-_ARROW_TO_PYTHON: dict[pa.DataType, type] = {v: k for k, v in _PYTHON_TO_ARROW.items()} + +@functools.cache +def _python_to_arrow() -> dict[type, pa.DataType]: + """Lazily-built Python-type → Arrow-type mapping (populated on first call).""" + return { + int: pa.int64(), + float: pa.float64(), + str: pa.string(), + bool: pa.bool_(), + bytes: pa.binary(), + } + + +@functools.cache +def _arrow_to_python() -> dict[pa.DataType, type]: + """Lazily-built Arrow-type → Python-type mapping (populated on first call).""" + return {v: k for k, v in _python_to_arrow().items()} # TODO: revisit and consider a way to incorporate older Union type DataType: TypeAlias = type | UnionType # | type[Union] From 9b1a9e966d521bc3788e8aa21d52e9ec9d5a9a89 Mon Sep 17 00:00:00 2001 From: "agent-kurouto[bot]" <268466204+agent-kurouto[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 10:46:54 +0000 Subject: [PATCH 8/8] refactor: remove dead _PYTHON_TO_ARROW / _ARROW_TO_PYTHON mappings These module-level dicts were unused (no references anywhere in the codebase) and were the only reason pyarrow needed to be called at module load time. Removing them simplifies the lazy-import change. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/orcapod/types.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 96f4e39..97ea2c1 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -11,7 +11,6 @@ from __future__ import annotations -import functools import logging import os import uuid @@ -32,24 +31,6 @@ logger = logging.getLogger(__name__) - -@functools.cache -def _python_to_arrow() -> dict[type, pa.DataType]: - """Lazily-built Python-type → Arrow-type mapping (populated on first call).""" - return { - int: pa.int64(), - float: pa.float64(), - str: pa.string(), - bool: pa.bool_(), - bytes: pa.binary(), - } - - -@functools.cache -def _arrow_to_python() -> dict[pa.DataType, type]: - """Lazily-built Arrow-type → Python-type mapping (populated on first call).""" - return {v: k for k, v in _python_to_arrow().items()} - # TODO: revisit and consider a way to incorporate older Union type DataType: TypeAlias = type | UnionType # | type[Union] """A Python type or union of types used to describe the data type of a single