diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index bfcda31..7b35af8 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -433,18 +433,13 @@ def keys( def pipeline_path(self) -> tuple[str, ...]: """Return the pipeline path for DB record scoping. - Raises: - RuntimeError: If no database is attached and this is not a - read-only deserialized node. + Returns ``()`` when no pipeline database is attached. """ stored = getattr(self, "_stored_pipeline_path", None) if self._packet_function is None and stored is not None: return stored if self._pipeline_database is None: - raise RuntimeError( - "Cannot compute pipeline_path without an attached database. " - "Call attach_databases() first." - ) + return () return ( self._pipeline_path_prefix + self._packet_function.uri @@ -521,7 +516,9 @@ def execute( obs = observer if observer is not None else NoOpObserver() - obs.on_node_start(node_label, node_hash) + pp = self.pipeline_path + tag_schema = input_stream.output_schema(columns={"system_tags": True})[0] + obs.on_node_start(node_label, node_hash, pipeline_path=pp, tag_schema=tag_schema) # Gather entry IDs and check cache upstream_entries = [ @@ -531,8 +528,6 @@ def execute( entry_ids = [eid for _, _, eid in upstream_entries] cached = self.get_cached_results(entry_ids=entry_ids) - pp = self.pipeline_path if self._pipeline_database is not None else () - output: list[tuple[TagProtocol, PacketProtocol]] = [] for tag, packet, entry_id in upstream_entries: obs.on_packet_start(node_label, tag, packet) @@ -559,7 +554,7 @@ def execute( ) obs.on_packet_crash(node_label, tag, packet, exc) if error_policy == "fail_fast": - obs.on_node_end(node_label, node_hash) + obs.on_node_end(node_label, node_hash, pipeline_path=pp) raise else: obs.on_packet_end( @@ -568,7 +563,7 @@ def execute( if result is not None: output.append((tag_out, result)) - obs.on_node_end(node_label, node_hash) + obs.on_node_end(node_label, node_hash, pipeline_path=pp) return output def _process_packet_internal( @@ -1243,8 +1238,11 @@ async def async_execute( obs = observer if observer is not None else NoOpObserver() + pp = self.pipeline_path + try: - obs.on_node_start(node_label, node_hash) + tag_schema = self._input_stream.output_schema(columns={"system_tags": True})[0] + obs.on_node_start(node_label, node_hash, pipeline_path=pp, tag_schema=tag_schema) if self._cached_function_pod is not None: # DB-backed async execution: @@ -1322,7 +1320,7 @@ async def async_execute( node_hash=node_hash, ) - obs.on_node_end(node_label, node_hash) + obs.on_node_end(node_label, node_hash, pipeline_path=pp) finally: await output.close() @@ -1337,7 +1335,7 @@ async def _async_execute_one_packet( node_hash: str, ) -> None: """Process one non-cached packet in the async execute path.""" - pp = self.pipeline_path if self._pipeline_database is not None else () + pp = self.pipeline_path observer.on_packet_start(node_label, tag, packet) ctx_obs = observer.contextualize(node_hash, node_label) diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 67dddfe..0ae8469 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,15 +1,19 @@ from .async_orchestrator import AsyncPipelineOrchestrator +from .composite_observer import CompositeObserver from .graph import Pipeline from .logging_observer import LoggingObserver, PacketLogger from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION +from .status_observer import StatusObserver from .sync_orchestrator import SyncPipelineOrchestrator __all__ = [ "AsyncPipelineOrchestrator", + "CompositeObserver", "LoadStatus", "LoggingObserver", "PacketLogger", "PIPELINE_FORMAT_VERSION", "Pipeline", + "StatusObserver", "SyncPipelineOrchestrator", ] diff --git a/src/orcapod/pipeline/composite_observer.py b/src/orcapod/pipeline/composite_observer.py new file mode 100644 index 0000000..128c9ce --- /dev/null +++ b/src/orcapod/pipeline/composite_observer.py @@ -0,0 +1,111 @@ +"""Composite observer that delegates to multiple child observers. + +Provides ``CompositeObserver``, a multiplexer that forwards every +``ExecutionObserverProtocol`` hook to N child observers. This allows +combining observers (e.g. ``LoggingObserver`` + ``StatusObserver``) +without modifying the orchestrator or node code. + +Example:: + + from orcapod.pipeline.composite_observer import CompositeObserver + from orcapod.pipeline.logging_observer import LoggingObserver + from orcapod.pipeline.status_observer import StatusObserver + from orcapod.databases import InMemoryArrowDatabase + + log_obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + status_obs = StatusObserver(status_database=InMemoryArrowDatabase()) + observer = CompositeObserver(log_obs, status_obs) + + pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=observer)) +""" + +from __future__ import annotations + +from typing import Any + +from orcapod.pipeline.observer import NoOpLogger +from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol +from orcapod.types import SchemaLike + +_NOOP_LOGGER = NoOpLogger() + + +class CompositeObserver: + """Observer that delegates all hooks to multiple child observers. + + Args: + *observers: Child observers to delegate to. Each must satisfy + ``ExecutionObserverProtocol``. + """ + + def __init__(self, *observers: Any) -> None: + self._observers = observers + + def contextualize( + self, node_hash: str, node_label: str + ) -> CompositeObserver: + """Return a composite of contextualized children.""" + return CompositeObserver( + *(obs.contextualize(node_hash, node_label) for obs in self._observers) + ) + + def on_run_start(self, run_id: str) -> None: + for obs in self._observers: + obs.on_run_start(run_id) + + def on_run_end(self, run_id: str) -> None: + for obs in self._observers: + obs.on_run_end(run_id) + + def on_node_start( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None + ) -> None: + for obs in self._observers: + obs.on_node_start(node_label, node_hash, pipeline_path=pipeline_path, tag_schema=tag_schema) + + def on_node_end( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = () + ) -> None: + for obs in self._observers: + obs.on_node_end(node_label, node_hash, pipeline_path=pipeline_path) + + def on_packet_start( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol + ) -> None: + for obs in self._observers: + obs.on_packet_start(node_label, tag, packet) + + def on_packet_end( + self, + node_label: str, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, + cached: bool, + ) -> None: + for obs in self._observers: + obs.on_packet_end(node_label, tag, input_packet, output_packet, cached) + + def on_packet_crash( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception + ) -> None: + for obs in self._observers: + obs.on_packet_crash(node_label, tag, packet, error) + + def create_packet_logger( + self, + tag: TagProtocol, + packet: PacketProtocol, + pipeline_path: tuple[str, ...] = (), + ) -> Any: + """Return the first non-no-op logger from children. + + Iterates through child observers and returns the logger from the + first child that provides a real (non-no-op) implementation. + Falls back to a no-op logger if all children return no-ops. + """ + for obs in self._observers: + pkt_logger = obs.create_packet_logger(tag, packet, pipeline_path=pipeline_path) + if not isinstance(pkt_logger, NoOpLogger): + return pkt_logger + return _NOOP_LOGGER diff --git a/src/orcapod/pipeline/logging_observer.py b/src/orcapod/pipeline/logging_observer.py index 00bd49c..a2aedb3 100644 --- a/src/orcapod/pipeline/logging_observer.py +++ b/src/orcapod/pipeline/logging_observer.py @@ -55,6 +55,8 @@ from uuid_utils import uuid7 from orcapod.pipeline.logging_capture import install_capture_streams +from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol +from orcapod.types import SchemaLike if TYPE_CHECKING: import pyarrow as pa @@ -173,36 +175,40 @@ def on_run_start(self, run_id: str) -> None: def on_run_end(self, run_id: str) -> None: self._parent.on_run_end(run_id) - def on_node_start(self, node_label: str, node_hash: str) -> None: - self._parent.on_node_start(node_label, node_hash) + def on_node_start( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None + ) -> None: + self._parent.on_node_start(node_label, node_hash, pipeline_path=pipeline_path, tag_schema=tag_schema) - def on_node_end(self, node_label: str, node_hash: str) -> None: - self._parent.on_node_end(node_label, node_hash) + def on_node_end( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = () + ) -> None: + self._parent.on_node_end(node_label, node_hash, pipeline_path=pipeline_path) def on_packet_start( - self, node_label: str, tag: Any, packet: Any + self, node_label: str, tag: TagProtocol, packet: PacketProtocol ) -> None: self._parent.on_packet_start(node_label, tag, packet) def on_packet_end( self, node_label: str, - tag: Any, - input_packet: Any, - output_packet: Any, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, cached: bool, ) -> None: self._parent.on_packet_end(node_label, tag, input_packet, output_packet, cached) def on_packet_crash( - self, node_label: str, tag: Any, packet: Any, error: Exception + self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception ) -> None: self._parent.on_packet_crash(node_label, tag, packet, error) def create_packet_logger( self, - tag: Any, - packet: Any, + tag: TagProtocol, + packet: PacketProtocol, pipeline_path: tuple[str, ...] = (), ) -> PacketLogger: """Create a logger using context from this wrapper.""" @@ -282,34 +288,38 @@ def on_run_start(self, run_id: str) -> None: def on_run_end(self, run_id: str) -> None: pass - def on_node_start(self, node_label: str, node_hash: str) -> None: + def on_node_start( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None + ) -> None: pass - def on_node_end(self, node_label: str, node_hash: str) -> None: + def on_node_end( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = () + ) -> None: pass - def on_packet_start(self, node_label: str, tag: Any, packet: Any) -> None: + def on_packet_start(self, node_label: str, tag: TagProtocol, packet: PacketProtocol) -> None: pass def on_packet_end( self, node_label: str, - tag: Any, - input_packet: Any, - output_packet: Any, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, cached: bool, ) -> None: pass def on_packet_crash( - self, node_label: str, tag: Any, packet: Any, error: Exception + self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception ) -> None: pass def create_packet_logger( self, - tag: Any, - packet: Any, + tag: TagProtocol, + packet: PacketProtocol, pipeline_path: tuple[str, ...] = (), ) -> PacketLogger: """Return a :class:`PacketLogger` bound to *tag* context. diff --git a/src/orcapod/pipeline/observer.py b/src/orcapod/pipeline/observer.py index 8359b8c..e9a33ed 100644 --- a/src/orcapod/pipeline/observer.py +++ b/src/orcapod/pipeline/observer.py @@ -9,6 +9,7 @@ from typing import Any from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol +from orcapod.types import SchemaLike from orcapod.protocols.observability_protocols import ( # noqa: F401 (re-exported for convenience) ExecutionObserverProtocol, PacketExecutionLoggerProtocol, @@ -58,10 +59,14 @@ def on_run_start(self, run_id: str) -> None: def on_run_end(self, run_id: str) -> None: pass - def on_node_start(self, node_label: str, node_hash: str) -> None: + def on_node_start( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None + ) -> None: pass - def on_node_end(self, node_label: str, node_hash: str) -> None: + def on_node_end( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = () + ) -> None: pass def on_packet_start( diff --git a/src/orcapod/pipeline/status_observer.py b/src/orcapod/pipeline/status_observer.py new file mode 100644 index 0000000..d4d9e90 --- /dev/null +++ b/src/orcapod/pipeline/status_observer.py @@ -0,0 +1,321 @@ +"""Run status observer for orcapod pipelines. + +Provides ``StatusObserver``, a drop-in observer that records per-packet +execution state transitions (``RUNNING``, ``SUCCESS``, ``FAILED``) to any +``ArrowDatabaseProtocol`` implementation (in-memory, Delta Lake, etc.). + +Example:: + + from orcapod.pipeline.status_observer import StatusObserver + from orcapod.pipeline import SyncPipelineOrchestrator + from orcapod.databases import InMemoryArrowDatabase + + obs = StatusObserver(status_database=InMemoryArrowDatabase()) + pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=obs)) + + # Inspect run status for a specific node + status = obs.get_status(pipeline_path=node.pipeline_path) # pyarrow.Table + status.to_pandas() # pandas DataFrame + +Status schema (fixed columns): + Fixed columns are prefixed with ``_status_`` to follow system column + conventions and avoid collision with user-defined tag column names. + + - ``_status_id`` (large_utf8): UUID7 unique to this status event. + - ``_status_run_id`` (large_utf8): UUID of the pipeline run (from ``on_run_start``). + - ``_status_node_label`` (large_utf8): Label of the function node. + - ``_status_node_hash`` (large_utf8): Content hash of the function node. + - ``_status_state`` (large_utf8): ``RUNNING``, ``SUCCESS``, ``FAILED``, or ``CACHED``. + - ``_status_timestamp`` (large_utf8): ISO-8601 UTC timestamp. + - ``_status_error_summary`` (large_utf8): Brief error on ``FAILED``; ``None`` otherwise. + + In addition, each tag key from the packet's tag becomes a separate + ``large_utf8`` column (queryable, not JSON-encoded). Tag columns use + bare names (no prefix), so they are always distinguishable from fixed + columns. + +Status storage: + Status events are stored at a pipeline-path-mirrored location: + ``pipeline_path[:1] + ("status",) + pipeline_path[1:]``. + Each function node gets its own status table. Use + ``get_status(pipeline_path=node.pipeline_path)`` to retrieve + node-specific status. + +Append-only: + Each state transition is a new row. Current state for a (node, tag) + combination within a run is the row with the latest ``_status_timestamp``. + If a ``RUNNING`` event has no subsequent terminal event for the same + ``run_id``, the process crashed. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from uuid_utils import uuid7 + +from orcapod.pipeline.observer import NoOpLogger +from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol +from orcapod.types import SchemaLike + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +logger = logging.getLogger(__name__) + +# Default path within the database where status rows are stored. +DEFAULT_STATUS_PATH: tuple[str, ...] = ("execution_status",) + +_NOOP_LOGGER = NoOpLogger() + + +class StatusObserver: + """Concrete observer that writes packet execution status to a database. + + Instantiate once, outside the pipeline, and pass to the orchestrator + (directly or via a ``CompositeObserver``):: + + obs = StatusObserver(status_database=InMemoryArrowDatabase()) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + # After the run, read back status: + status_table = obs.get_status() # pyarrow.Table + + Args: + status_database: Any ``ArrowDatabaseProtocol`` instance. + status_path: Tuple of strings identifying the default table within + the database. Defaults to ``("execution_status",)``. + """ + + def __init__( + self, + status_database: ArrowDatabaseProtocol, + status_path: tuple[str, ...] | None = None, + ) -> None: + self._db = status_database + self._status_path = status_path or DEFAULT_STATUS_PATH + self._current_run_id: str = "" + # Tracks (node_hash, pipeline_path, tag_schema) per node_label, + # populated by on_node_start. Allows packet-level hooks (which + # only receive node_label) to look up the node's identity, + # storage path, and tag schema. + self._node_context: dict[str, tuple[str, tuple[str, ...], SchemaLike]] = {} + + # -- contextualize -- + + def contextualize( + self, node_hash: str, node_label: str + ) -> _ContextualizedStatusObserver: + """Return a contextualized wrapper stamped with node identity.""" + return _ContextualizedStatusObserver(self, node_hash, node_label) + + # -- lifecycle hooks -- + + def on_run_start(self, run_id: str) -> None: + self._current_run_id = run_id + self._node_context.clear() + + def on_run_end(self, run_id: str) -> None: + self._node_context.clear() + + def on_node_start( + self, + node_label: str, + node_hash: str, + pipeline_path: tuple[str, ...] = (), + tag_schema: SchemaLike | None = None, + ) -> None: + self._node_context[node_label] = (node_hash, pipeline_path, tag_schema or {}) + + def on_node_end( + self, + node_label: str, + node_hash: str, + pipeline_path: tuple[str, ...] = (), + ) -> None: + self._node_context.pop(node_label, None) + + def on_packet_start( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol + ) -> None: + self._write_event(node_label, tag, state="RUNNING") + + def on_packet_end( + self, + node_label: str, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, + cached: bool, + ) -> None: + self._write_event(node_label, tag, state="CACHED" if cached else "SUCCESS") + + def on_packet_crash( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception + ) -> None: + self._write_event( + node_label, tag, state="FAILED", error=error + ) + + def create_packet_logger( + self, + tag: TagProtocol, + packet: PacketProtocol, + pipeline_path: tuple[str, ...] = (), + ) -> NoOpLogger: + """Return a no-op logger. + + Status events are written from observer hooks, not from the + packet logger. + """ + return _NOOP_LOGGER + + # -- convenience -- + + def get_status( + self, pipeline_path: tuple[str, ...] | None = None + ) -> pa.Table | None: + """Read status rows from the database as a ``pyarrow.Table``. + + Args: + pipeline_path: If provided, reads status for a specific node + (mirrored path). If ``None``, reads from the default + status path. + + Returns: + ``None`` if no status events have been written yet. + """ + if pipeline_path is not None: + status_path = pipeline_path[:1] + ("status",) + pipeline_path[1:] + else: + status_path = self._status_path + return self._db.get_all_records(status_path) + + # -- internal -- + + def _write_event( + self, + node_label: str, + tag: TagProtocol, + state: str, + error: Exception | None = None, + ) -> None: + """Build and write a single status event row.""" + import pyarrow as pa + + node_hash, pipeline_path, tag_schema = self._node_context.get( + node_label, ("", (), {}) + ) + + # Compute mirrored status path + if pipeline_path: + status_path = pipeline_path[:1] + ("status",) + pipeline_path[1:] + else: + status_path = self._status_path + + status_id = str(uuid7()) + timestamp = datetime.now(timezone.utc).isoformat() + + columns: dict[str, pa.Array] = { + "_status_id": pa.array([status_id], type=pa.large_utf8()), + "_status_run_id": pa.array([self._current_run_id], type=pa.large_utf8()), + "_status_node_label": pa.array([node_label], type=pa.large_utf8()), + "_status_node_hash": pa.array([node_hash], type=pa.large_utf8()), + "_status_state": pa.array([state], type=pa.large_utf8()), + "_status_timestamp": pa.array([timestamp], type=pa.large_utf8()), + "_status_error_summary": pa.array( + [str(error) if error is not None else None], + type=pa.large_utf8(), + ), + } + + # Tag columns — use statically-known schema from on_node_start + for key in tag_schema: + value = tag.get(key, None) + columns[key] = pa.array( + [str(value) if value is not None else None], + type=pa.large_utf8(), + ) + + row = pa.table(columns) + try: + self._db.add_record(status_path, status_id, row, flush=True) + except Exception: + logger.exception( + "StatusObserver: failed to write status event for node=%s state=%s", + node_label, + state, + ) + + +class _ContextualizedStatusObserver: + """Lightweight wrapper holding parent observer + node identity context. + + Created by ``StatusObserver.contextualize()``. All lifecycle hooks + delegate to the parent. + """ + + def __init__( + self, + parent: StatusObserver, + node_hash: str, + node_label: str, + ) -> None: + self._parent = parent + + def contextualize( + self, node_hash: str, node_label: str + ) -> _ContextualizedStatusObserver: + """Re-contextualize (returns a new wrapper with updated identity).""" + return _ContextualizedStatusObserver(self._parent, node_hash, node_label) + + def on_run_start(self, run_id: str) -> None: + self._parent.on_run_start(run_id) + + def on_run_end(self, run_id: str) -> None: + self._parent.on_run_end(run_id) + + def on_node_start( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None + ) -> None: + self._parent.on_node_start(node_label, node_hash, pipeline_path=pipeline_path, tag_schema=tag_schema) + + def on_node_end( + self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = () + ) -> None: + self._parent.on_node_end(node_label, node_hash, pipeline_path=pipeline_path) + + def on_packet_start( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol + ) -> None: + self._parent.on_packet_start(node_label, tag, packet) + + def on_packet_end( + self, + node_label: str, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, + cached: bool, + ) -> None: + self._parent.on_packet_end( + node_label, tag, input_packet, output_packet, cached + ) + + def on_packet_crash( + self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception + ) -> None: + self._parent.on_packet_crash(node_label, tag, packet, error) + + def create_packet_logger( + self, + tag: TagProtocol, + packet: PacketProtocol, + pipeline_path: tuple[str, ...] = (), + ) -> NoOpLogger: + return _NOOP_LOGGER diff --git a/src/orcapod/protocols/node_protocols.py b/src/orcapod/protocols/node_protocols.py index 6f91036..2471b87 100644 --- a/src/orcapod/protocols/node_protocols.py +++ b/src/orcapod/protocols/node_protocols.py @@ -51,6 +51,14 @@ class FunctionNodeProtocol(Protocol): node_type: str + @property + def pipeline_path(self) -> tuple[str, ...]: + """The node's pipeline path for storage scoping. + + Returns ``()`` when no pipeline database is attached. + """ + ... + def execute( self, input_stream: StreamProtocol, diff --git a/src/orcapod/protocols/observability_protocols.py b/src/orcapod/protocols/observability_protocols.py index 4982289..457d29c 100644 --- a/src/orcapod/protocols/observability_protocols.py +++ b/src/orcapod/protocols/observability_protocols.py @@ -16,6 +16,7 @@ from typing import Any, Protocol, runtime_checkable from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol +from orcapod.types import SchemaLike @runtime_checkable @@ -95,12 +96,37 @@ def on_run_end(self, run_id: str) -> None: """ ... - def on_node_start(self, node_label: str, node_hash: str) -> None: - """Called before a node begins processing its packets.""" + def on_node_start( + self, + node_label: str, + node_hash: str, + pipeline_path: tuple[str, ...] = (), + tag_schema: SchemaLike | None = None, + ) -> None: + """Called before a node begins processing its packets. + + Args: + node_label: Human-readable label of the node. + node_hash: Content hash of the node. + pipeline_path: The node's pipeline path for storage scoping. + tag_schema: The tag schema (including system tags) for this + node's input stream. + """ ... - def on_node_end(self, node_label: str, node_hash: str) -> None: - """Called after a node finishes processing all packets.""" + def on_node_end( + self, + node_label: str, + node_hash: str, + pipeline_path: tuple[str, ...] = (), + ) -> None: + """Called after a node finishes processing all packets. + + Args: + node_label: Human-readable label of the node. + node_hash: Content hash of the node. + pipeline_path: The node's pipeline path for storage scoping. + """ ... def on_packet_start( diff --git a/tests/test_pipeline/test_composite_observer.py b/tests/test_pipeline/test_composite_observer.py new file mode 100644 index 0000000..af5af93 --- /dev/null +++ b/tests/test_pipeline/test_composite_observer.py @@ -0,0 +1,218 @@ +"""Tests for CompositeObserver. + +Verifies that CompositeObserver correctly delegates all hooks to +multiple child observers and that create_packet_logger returns the +first real (non-no-op) logger. +""" + +from __future__ import annotations + +import pyarrow as pa + +from orcapod.core.executors import LocalExecutor +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources.arrow_table_source import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import ( + Pipeline, + SyncPipelineOrchestrator, +) +from orcapod.pipeline.composite_observer import CompositeObserver +from orcapod.pipeline.logging_observer import LoggingObserver, PacketLogger +from orcapod.pipeline.observer import NoOpLogger +from orcapod.pipeline.status_observer import StatusObserver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table({ + "id": pa.array([str(i) for i in range(n)], type=pa.large_string()), + "x": pa.array([10 * (i + 1) for i in range(n)], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _get_function_node(pipeline: Pipeline): + """Return the first function node from the pipeline graph.""" + import networkx as nx + + for node in nx.topological_sort(pipeline._node_graph): + if node.node_type == "function": + return node + raise RuntimeError("No function node found") + + +# --------------------------------------------------------------------------- +# 1. Integration: logging + status together +# --------------------------------------------------------------------------- + + +class TestLoggingAndStatusTogether: + def test_both_observers_populated(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_composite", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + log_obs = LoggingObserver(log_database=db) + status_obs = StatusObserver(status_database=db) + observer = CompositeObserver(log_obs, status_obs) + + orch = SyncPipelineOrchestrator(observer=observer) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + + # Logging observer should have logs + logs = log_obs.get_logs(pipeline_path=fn_node.pipeline_path) + assert logs is not None + assert logs.num_rows == 2 + + # Status observer should have status events + status = status_obs.get_status(pipeline_path=fn_node.pipeline_path) + assert status is not None + assert status.num_rows == 4 # 2 × (RUNNING + SUCCESS) + + +# --------------------------------------------------------------------------- +# 2. create_packet_logger returns the real logger, not no-op +# --------------------------------------------------------------------------- + + +class TestCreatePacketLoggerDelegation: + def test_returns_logging_observer_logger(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + print("hello") + return x + + pf = PythonPacketFunction(identity, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_logger_delegation", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + log_obs = LoggingObserver(log_database=db) + status_obs = StatusObserver(status_database=db) + observer = CompositeObserver(log_obs, status_obs) + + orch = SyncPipelineOrchestrator(observer=observer) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + + # Logs should have captured the print output, proving the real + # LoggingObserver logger was used (not the no-op) + logs = log_obs.get_logs(pipeline_path=fn_node.pipeline_path) + assert logs is not None + stdout = logs.column("_log_stdout_log").to_pylist()[0] + assert "hello" in stdout + + +# --------------------------------------------------------------------------- +# 3. Contextualize returns a composite +# --------------------------------------------------------------------------- + + +class TestContextualizeReturnsComposite: + def test_contextualized_composite_delegates(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final", executor=LocalExecutor()) + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_ctx_composite", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + log_obs = LoggingObserver(log_database=db) + status_obs = StatusObserver(status_database=db) + observer = CompositeObserver(log_obs, status_obs) + + orch = SyncPipelineOrchestrator(observer=observer) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + + # Both nodes should have both logs and status + for fn_node in fn_nodes: + logs = log_obs.get_logs(pipeline_path=fn_node.pipeline_path) + status = status_obs.get_status(pipeline_path=fn_node.pipeline_path) + assert logs is not None + assert status is not None + assert logs.num_rows == 2 + assert status.num_rows == 4 # 2 × (RUNNING + SUCCESS) + + +# --------------------------------------------------------------------------- +# 4. Mixed success/failure with composite +# --------------------------------------------------------------------------- + + +class TestCompositeWithFailures: + def test_failures_tracked_by_both_observers(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_composite_fail", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + log_obs = LoggingObserver(log_database=db) + status_obs = StatusObserver(status_database=db) + observer = CompositeObserver(log_obs, status_obs) + + orch = SyncPipelineOrchestrator(observer=observer) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + + # Logs should show failures + logs = log_obs.get_logs(pipeline_path=fn_node.pipeline_path) + assert logs is not None + assert all(s is False for s in logs.column("_log_success").to_pylist()) + + # Status should show RUNNING + FAILED + status = status_obs.get_status(pipeline_path=fn_node.pipeline_path) + assert status is not None + states = status.column("_status_state").to_pylist() + assert states.count("RUNNING") == 2 + assert states.count("FAILED") == 2 diff --git a/tests/test_pipeline/test_node_protocols.py b/tests/test_pipeline/test_node_protocols.py index 20d210b..d4ab780 100644 --- a/tests/test_pipeline/test_node_protocols.py +++ b/tests/test_pipeline/test_node_protocols.py @@ -48,6 +48,10 @@ def test_requires_execute_and_async_execute(self): class GoodFunction: node_type = "function" + @property + def pipeline_path(self): + return () + def execute(self, input_stream, *, observer=None): return [] @@ -157,9 +161,9 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("end", node_label)) def on_packet_start(self, node_label, t, p): pass @@ -207,9 +211,9 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append("start") - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append("end") def on_packet_start(self, node_label, t, p): pass @@ -252,9 +256,9 @@ def test_execute_with_observer(self): class Obs: def contextualize(self, node_hash, node_label): return self - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, t, p): events.append(("packet_start",)) @@ -331,9 +335,9 @@ async def test_async_execute_with_observer(self): class Obs: def contextualize(self, node_hash, node_label): return self - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append("node_start") - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append("node_end") def on_packet_start(self, node_label, t, p): events.append("pkt_start") @@ -383,9 +387,9 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, t, p): pass @@ -424,9 +428,9 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append("start") - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append("end") def on_packet_start(self, node_label, t, p): pass diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index 587b4d1..b6c1c9c 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -496,9 +496,9 @@ def test_linear_pipeline_observer_hooks(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, tag, packet): events.append(("packet_start", node_label)) @@ -553,9 +553,9 @@ def double_val(val: int) -> int: class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, tag, packet): events.append(("packet_start", node_label)) diff --git a/tests/test_pipeline/test_status_observer_integration.py b/tests/test_pipeline/test_status_observer_integration.py new file mode 100644 index 0000000..0906323 --- /dev/null +++ b/tests/test_pipeline/test_status_observer_integration.py @@ -0,0 +1,485 @@ +"""Integration tests for StatusObserver with real pipelines. + +Exercises the full status tracking pipeline: observer hooks → +StatusObserver → database, using InMemoryArrowDatabase and real +Pipeline objects. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.executors import LocalExecutor +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources.arrow_table_source import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import ( + AsyncPipelineOrchestrator, + Pipeline, + SyncPipelineOrchestrator, +) +from orcapod.pipeline.status_observer import StatusObserver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table({ + "id": pa.array([str(i) for i in range(n)], type=pa.large_string()), + "x": pa.array([10 * (i + 1) for i in range(n)], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _get_function_node(pipeline: Pipeline): + """Return the first function node from the pipeline graph.""" + import networkx as nx + + for node in nx.topological_sort(pipeline._node_graph): + if node.node_type == "function": + return node + raise RuntimeError("No function node found") + + +# --------------------------------------------------------------------------- +# 1. Sync pipeline — success status events +# --------------------------------------------------------------------------- + + +class TestSyncPipelineSuccessStatus: + def test_success_produces_running_and_success_events(self): + db = InMemoryArrowDatabase() + source = _make_source() + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_status", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + assert status is not None + # 3 packets × 2 events each (RUNNING + SUCCESS) = 6 rows + assert status.num_rows == 6 + + states = status.column("_status_state").to_pylist() + assert states.count("RUNNING") == 3 + assert states.count("SUCCESS") == 3 + + +# --------------------------------------------------------------------------- +# 2. Failing packets → FAILED status with error summary +# --------------------------------------------------------------------------- + + +class TestFailingPacketsStatus: + def test_failure_status_with_error_summary(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_fail_status", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + assert status is not None + # 2 packets × 2 events each (RUNNING + FAILED) + assert status.num_rows == 4 + + states = status.column("_status_state").to_pylist() + assert states.count("RUNNING") == 2 + assert states.count("FAILED") == 2 + + # Error summary should be populated for FAILED events + for i, state in enumerate(states): + error_summary = status.column("_status_error_summary").to_pylist()[i] + if state == "FAILED": + assert error_summary is not None + assert "boom" in error_summary + else: + assert error_summary is None + + +# --------------------------------------------------------------------------- +# 3. Pipeline-path-mirrored storage +# --------------------------------------------------------------------------- + + +class TestPipelinePathMirroredStorage: + def test_status_path_mirrors_pipeline_path(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mirror_status", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + pp = fn_node.pipeline_path + expected_status_path = pp[:1] + ("status",) + pp[1:] + + # Verify the status path is correct by reading directly from the DB + raw = db.get_all_records(expected_status_path) + assert raw is not None + assert raw.num_rows == 2 # RUNNING + SUCCESS + + +# --------------------------------------------------------------------------- +# 4. Queryable tag columns +# --------------------------------------------------------------------------- + + +class TestQueryableTagColumns: + def test_tag_columns_in_status_table(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_tags_status", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + assert status is not None + # "id" tag column should be a separate column, not JSON + assert "id" in status.column_names + id_values = sorted(set(status.column("id").to_pylist())) + assert id_values == ["0", "1"] + + +# --------------------------------------------------------------------------- +# 5. Async orchestrator status +# --------------------------------------------------------------------------- + + +class TestAsyncOrchestratorStatus: + def test_async_pipeline_captures_status(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_async_status", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = StatusObserver(status_database=db) + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + assert status is not None + assert status.num_rows == 4 # 2 × (RUNNING + SUCCESS) + + states = status.column("_status_state").to_pylist() + assert states.count("RUNNING") == 2 + assert states.count("SUCCESS") == 2 + + +# --------------------------------------------------------------------------- +# 6. fail_fast error policy +# --------------------------------------------------------------------------- + + +class TestFailFastErrorPolicy: + def test_fail_fast_aborts_and_records_status(self): + db = InMemoryArrowDatabase() + source = _make_source(3) + + def failing(x: int) -> int: + raise RuntimeError("crash") + + pf = PythonPacketFunction(failing, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_failfast_status", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs, error_policy="fail_fast") + + with pytest.raises(RuntimeError, match="crash"): + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + # At least one RUNNING + one FAILED before abort + assert status is not None + assert status.num_rows >= 2 + states = status.column("_status_state").to_pylist() + assert "RUNNING" in states + assert "FAILED" in states + + +# --------------------------------------------------------------------------- +# 7. Mixed success/failure — correct status per packet +# --------------------------------------------------------------------------- + + +class TestMixedSuccessFailure: + def test_mixed_results_tracked_correctly(self): + db = InMemoryArrowDatabase() + source = ArrowTableSource( + pa.table({ + "id": pa.array(["a", "b", "c"], type=pa.large_string()), + "x": pa.array([10, 0, 30], type=pa.int64()), + }), + tag_columns=["id"], + ) + + def safe_div(x: int) -> float: + return 100 / x # x=0 will raise ZeroDivisionError + + pf = PythonPacketFunction(safe_div, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mixed_status", pipeline_database=db) + with pipeline: + pod(source, label="divider") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + assert status is not None + # 3 RUNNING + 2 SUCCESS + 1 FAILED = 6 + assert status.num_rows == 6 + + states = status.column("_status_state").to_pylist() + assert states.count("RUNNING") == 3 + assert states.count("SUCCESS") == 2 + assert states.count("FAILED") == 1 + + +# --------------------------------------------------------------------------- +# 8. Multiple function nodes — each gets own status table +# --------------------------------------------------------------------------- + + +class TestMultipleFunctionNodesSeparateStatus: + def test_two_nodes_separate_status_tables(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final", executor=LocalExecutor()) + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_multi_status", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + assert len(fn_nodes) == 2 + + status1 = obs.get_status(pipeline_path=fn_nodes[0].pipeline_path) + status2 = obs.get_status(pipeline_path=fn_nodes[1].pipeline_path) + + assert status1 is not None + assert status2 is not None + # Each node: 2 packets × 2 events = 4 rows + assert status1.num_rows == 4 + assert status2.num_rows == 4 + + # Verify they are at different paths + assert fn_nodes[0].pipeline_path != fn_nodes[1].pipeline_path + + +# --------------------------------------------------------------------------- +# 9. get_status(pipeline_path) retrieves node-specific status +# --------------------------------------------------------------------------- + + +class TestGetStatusNodeSpecific: + def test_get_status_filters_by_node(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result", executor=LocalExecutor()) + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final", executor=LocalExecutor()) + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_filter_status", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + + # Each node's status contains only that node's label + status1 = obs.get_status(pipeline_path=fn_nodes[0].pipeline_path) + status2 = obs.get_status(pipeline_path=fn_nodes[1].pipeline_path) + + labels1 = set(status1.column("_status_node_label").to_pylist()) + labels2 = set(status2.column("_status_node_label").to_pylist()) + + assert labels1 == {"doubler"} + assert labels2 == {"tripler"} + + +# --------------------------------------------------------------------------- +# 10. Status columns have correct schema +# --------------------------------------------------------------------------- + + +class TestStatusSchema: + def test_all_expected_columns_present(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_schema", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + expected_system_cols = { + "_status_id", + "_status_run_id", + "_status_node_label", + "_status_node_hash", + "_status_state", + "_status_timestamp", + "_status_error_summary", + } + assert expected_system_cols.issubset(set(status.column_names)) + + # Tag column should also be present + assert "id" in status.column_names + + +# --------------------------------------------------------------------------- +# 11. run_id is tracked correctly +# --------------------------------------------------------------------------- + + +class TestRunIdTracking: + def test_run_id_populated_in_status(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result", executor=LocalExecutor()) + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_runid", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = StatusObserver(status_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + # Pass run_id via the orchestrator's run() method directly + orch.run(pipeline._node_graph, run_id="my-custom-run-id") + + fn_node = _get_function_node(pipeline) + status = obs.get_status(pipeline_path=fn_node.pipeline_path) + + run_ids = set(status.column("_status_run_id").to_pylist()) + assert run_ids == {"my-custom-run-id"} diff --git a/tests/test_pipeline/test_sync_orchestrator.py b/tests/test_pipeline/test_sync_orchestrator.py index bdd58ef..6954a28 100644 --- a/tests/test_pipeline/test_sync_orchestrator.py +++ b/tests/test_pipeline/test_sync_orchestrator.py @@ -132,9 +132,9 @@ def test_observer_hooks_fire(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) @@ -213,9 +213,9 @@ def test_run_with_explicit_orchestrator(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) @@ -423,9 +423,9 @@ def double_val(val: int) -> int: class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): events.append(("node_start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): events.append(("node_end", node_label)) def on_packet_start(self, node_label, tag, packet): events.append(("packet_start", node_label)) @@ -469,8 +469,8 @@ def test_function_node_cached_flag(self): class Obs1: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): pass - def on_node_end(self, node_label, node_hash): pass + def on_node_start(self, node_label, node_hash, **kwargs): pass + def on_node_end(self, node_label, node_hash, **kwargs): pass def on_packet_start(self, node_label, tag, packet): pass def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): if node_label == "doubler": @@ -491,8 +491,8 @@ def contextualize(self, node_hash, node_label): class Obs2: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): pass - def on_node_end(self, node_label, node_hash): pass + def on_node_start(self, node_label, node_hash, **kwargs): pass + def on_node_end(self, node_label, node_hash, **kwargs): pass def on_packet_start(self, node_label, tag, packet): pass def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): if node_label == "doubler": @@ -524,9 +524,9 @@ def test_diamond_dag_observer_event_order(self): class OrderObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node_label, node_hash): + def on_node_start(self, node_label, node_hash, **kwargs): node_order.append(("start", node_label)) - def on_node_end(self, node_label, node_hash): + def on_node_end(self, node_label, node_hash, **kwargs): node_order.append(("end", node_label)) def on_packet_start(self, node_label, tag, packet): pass def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): pass