Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/orcapod/core/nodes/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,18 +433,13 @@ def keys(
def pipeline_path(self) -> tuple[str, ...]:
"""Return the pipeline path for DB record scoping.

Raises:
RuntimeError: If no database is attached and this is not a
read-only deserialized node.
Returns ``()`` when no pipeline database is attached.
"""
stored = getattr(self, "_stored_pipeline_path", None)
if self._packet_function is None and stored is not None:
return stored
if self._pipeline_database is None:
raise RuntimeError(
"Cannot compute pipeline_path without an attached database. "
"Call attach_databases() first."
)
return ()
return (
self._pipeline_path_prefix
+ self._packet_function.uri
Expand Down Expand Up @@ -521,7 +516,9 @@ def execute(

obs = observer if observer is not None else NoOpObserver()

obs.on_node_start(node_label, node_hash)
pp = self.pipeline_path
tag_schema = input_stream.output_schema(columns={"system_tags": True})[0]
obs.on_node_start(node_label, node_hash, pipeline_path=pp, tag_schema=tag_schema)

# Gather entry IDs and check cache
upstream_entries = [
Expand All @@ -531,8 +528,6 @@ def execute(
entry_ids = [eid for _, _, eid in upstream_entries]
cached = self.get_cached_results(entry_ids=entry_ids)

pp = self.pipeline_path if self._pipeline_database is not None else ()

output: list[tuple[TagProtocol, PacketProtocol]] = []
for tag, packet, entry_id in upstream_entries:
obs.on_packet_start(node_label, tag, packet)
Expand All @@ -559,7 +554,7 @@ def execute(
)
obs.on_packet_crash(node_label, tag, packet, exc)
if error_policy == "fail_fast":
obs.on_node_end(node_label, node_hash)
obs.on_node_end(node_label, node_hash, pipeline_path=pp)
raise
else:
obs.on_packet_end(
Expand All @@ -568,7 +563,7 @@ def execute(
if result is not None:
output.append((tag_out, result))

obs.on_node_end(node_label, node_hash)
obs.on_node_end(node_label, node_hash, pipeline_path=pp)
return output

def _process_packet_internal(
Expand Down Expand Up @@ -1243,8 +1238,11 @@ async def async_execute(

obs = observer if observer is not None else NoOpObserver()

pp = self.pipeline_path

try:
obs.on_node_start(node_label, node_hash)
tag_schema = self._input_stream.output_schema(columns={"system_tags": True})[0]
obs.on_node_start(node_label, node_hash, pipeline_path=pp, tag_schema=tag_schema)

if self._cached_function_pod is not None:
# DB-backed async execution:
Expand Down Expand Up @@ -1322,7 +1320,7 @@ async def async_execute(
node_hash=node_hash,
)

obs.on_node_end(node_label, node_hash)
obs.on_node_end(node_label, node_hash, pipeline_path=pp)
finally:
await output.close()

Expand All @@ -1337,7 +1335,7 @@ async def _async_execute_one_packet(
node_hash: str,
) -> None:
"""Process one non-cached packet in the async execute path."""
pp = self.pipeline_path if self._pipeline_database is not None else ()
pp = self.pipeline_path

observer.on_packet_start(node_label, tag, packet)
ctx_obs = observer.contextualize(node_hash, node_label)
Expand Down
4 changes: 4 additions & 0 deletions src/orcapod/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from .async_orchestrator import AsyncPipelineOrchestrator
from .composite_observer import CompositeObserver
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are starting to accumulate multiple types of observers. Let's open an issue to aggregate them into their own subpackage in the future.

from .graph import Pipeline
from .logging_observer import LoggingObserver, PacketLogger
from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION
from .status_observer import StatusObserver
from .sync_orchestrator import SyncPipelineOrchestrator

__all__ = [
"AsyncPipelineOrchestrator",
"CompositeObserver",
"LoadStatus",
"LoggingObserver",
"PacketLogger",
"PIPELINE_FORMAT_VERSION",
"Pipeline",
"StatusObserver",
"SyncPipelineOrchestrator",
]
111 changes: 111 additions & 0 deletions src/orcapod/pipeline/composite_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Composite observer that delegates to multiple child observers.

Provides ``CompositeObserver``, a multiplexer that forwards every
``ExecutionObserverProtocol`` hook to N child observers. This allows
combining observers (e.g. ``LoggingObserver`` + ``StatusObserver``)
without modifying the orchestrator or node code.

Example::

from orcapod.pipeline.composite_observer import CompositeObserver
from orcapod.pipeline.logging_observer import LoggingObserver
from orcapod.pipeline.status_observer import StatusObserver
from orcapod.databases import InMemoryArrowDatabase

log_obs = LoggingObserver(log_database=InMemoryArrowDatabase())
status_obs = StatusObserver(status_database=InMemoryArrowDatabase())
observer = CompositeObserver(log_obs, status_obs)

pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=observer))
"""

from __future__ import annotations

from typing import Any

from orcapod.pipeline.observer import NoOpLogger
from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol
from orcapod.types import SchemaLike

_NOOP_LOGGER = NoOpLogger()


class CompositeObserver:
    """Multiplexing observer that fans every hook out to child observers.

    Args:
        *observers: Child observers to delegate to, each satisfying
            ``ExecutionObserverProtocol``. Hooks are invoked on them in
            the order given.
    """

    def __init__(self, *observers: Any) -> None:
        # The *args tuple fixes delegation order for every hook below.
        self._observers = observers

    def contextualize(
        self, node_hash: str, node_label: str
    ) -> CompositeObserver:
        """Return a new composite wrapping each child's contextualized form."""
        contextualized = [
            child.contextualize(node_hash, node_label)
            for child in self._observers
        ]
        return CompositeObserver(*contextualized)

    def on_run_start(self, run_id: str) -> None:
        for child in self._observers:
            child.on_run_start(run_id)

    def on_run_end(self, run_id: str) -> None:
        for child in self._observers:
            child.on_run_end(run_id)

    def on_node_start(
        self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None
    ) -> None:
        for child in self._observers:
            child.on_node_start(
                node_label,
                node_hash,
                pipeline_path=pipeline_path,
                tag_schema=tag_schema,
            )

    def on_node_end(
        self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = ()
    ) -> None:
        for child in self._observers:
            child.on_node_end(node_label, node_hash, pipeline_path=pipeline_path)

    def on_packet_start(
        self, node_label: str, tag: TagProtocol, packet: PacketProtocol
    ) -> None:
        for child in self._observers:
            child.on_packet_start(node_label, tag, packet)

    def on_packet_end(
        self,
        node_label: str,
        tag: TagProtocol,
        input_packet: PacketProtocol,
        output_packet: PacketProtocol | None,
        cached: bool,
    ) -> None:
        for child in self._observers:
            child.on_packet_end(
                node_label, tag, input_packet, output_packet, cached
            )

    def on_packet_crash(
        self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception
    ) -> None:
        for child in self._observers:
            child.on_packet_crash(node_label, tag, packet, error)

    def create_packet_logger(
        self,
        tag: TagProtocol,
        packet: PacketProtocol,
        pipeline_path: tuple[str, ...] = (),
    ) -> Any:
        """Return the first real (non-no-op) packet logger among children.

        Asks each child in order for a packet logger and returns the first
        one that is not a ``NoOpLogger``. If every child hands back a
        no-op, a shared no-op logger is returned instead.
        """
        for child in self._observers:
            candidate = child.create_packet_logger(
                tag, packet, pipeline_path=pipeline_path
            )
            if isinstance(candidate, NoOpLogger):
                continue
            return candidate
        return _NOOP_LOGGER
50 changes: 30 additions & 20 deletions src/orcapod/pipeline/logging_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from uuid_utils import uuid7

from orcapod.pipeline.logging_capture import install_capture_streams
from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol
from orcapod.types import SchemaLike

if TYPE_CHECKING:
import pyarrow as pa
Expand Down Expand Up @@ -173,36 +175,40 @@ def on_run_start(self, run_id: str) -> None:
def on_run_end(self, run_id: str) -> None:
self._parent.on_run_end(run_id)

def on_node_start(self, node_label: str, node_hash: str) -> None:
self._parent.on_node_start(node_label, node_hash)
def on_node_start(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None
) -> None:
self._parent.on_node_start(node_label, node_hash, pipeline_path=pipeline_path, tag_schema=tag_schema)

def on_node_end(self, node_label: str, node_hash: str) -> None:
self._parent.on_node_end(node_label, node_hash)
def on_node_end(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = ()
) -> None:
self._parent.on_node_end(node_label, node_hash, pipeline_path=pipeline_path)

def on_packet_start(
self, node_label: str, tag: Any, packet: Any
self, node_label: str, tag: TagProtocol, packet: PacketProtocol
) -> None:
self._parent.on_packet_start(node_label, tag, packet)

def on_packet_end(
self,
node_label: str,
tag: Any,
input_packet: Any,
output_packet: Any,
tag: TagProtocol,
input_packet: PacketProtocol,
output_packet: PacketProtocol | None,
cached: bool,
) -> None:
self._parent.on_packet_end(node_label, tag, input_packet, output_packet, cached)

def on_packet_crash(
self, node_label: str, tag: Any, packet: Any, error: Exception
self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception
) -> None:
self._parent.on_packet_crash(node_label, tag, packet, error)

def create_packet_logger(
self,
tag: Any,
packet: Any,
tag: TagProtocol,
packet: PacketProtocol,
pipeline_path: tuple[str, ...] = (),
) -> PacketLogger:
"""Create a logger using context from this wrapper."""
Expand Down Expand Up @@ -282,34 +288,38 @@ def on_run_start(self, run_id: str) -> None:
def on_run_end(self, run_id: str) -> None:
pass

def on_node_start(self, node_label: str, node_hash: str) -> None:
def on_node_start(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None
) -> None:
pass

def on_node_end(self, node_label: str, node_hash: str) -> None:
def on_node_end(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = ()
) -> None:
pass

def on_packet_start(self, node_label: str, tag: Any, packet: Any) -> None:
def on_packet_start(self, node_label: str, tag: TagProtocol, packet: PacketProtocol) -> None:
pass

def on_packet_end(
self,
node_label: str,
tag: Any,
input_packet: Any,
output_packet: Any,
tag: TagProtocol,
input_packet: PacketProtocol,
output_packet: PacketProtocol | None,
cached: bool,
) -> None:
pass

def on_packet_crash(
self, node_label: str, tag: Any, packet: Any, error: Exception
self, node_label: str, tag: TagProtocol, packet: PacketProtocol, error: Exception
) -> None:
pass

def create_packet_logger(
self,
tag: Any,
packet: Any,
tag: TagProtocol,
packet: PacketProtocol,
pipeline_path: tuple[str, ...] = (),
) -> PacketLogger:
"""Return a :class:`PacketLogger` bound to *tag* context.
Expand Down
9 changes: 7 additions & 2 deletions src/orcapod/pipeline/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any

from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol
from orcapod.types import SchemaLike
from orcapod.protocols.observability_protocols import ( # noqa: F401 (re-exported for convenience)
ExecutionObserverProtocol,
PacketExecutionLoggerProtocol,
Expand Down Expand Up @@ -58,10 +59,14 @@ def on_run_start(self, run_id: str) -> None:
def on_run_end(self, run_id: str) -> None:
pass

def on_node_start(self, node_label: str, node_hash: str) -> None:
def on_node_start(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = (), tag_schema: SchemaLike | None = None
) -> None:
pass

def on_node_end(self, node_label: str, node_hash: str) -> None:
def on_node_end(
self, node_label: str, node_hash: str, pipeline_path: tuple[str, ...] = ()
) -> None:
pass

def on_packet_start(
Expand Down
Loading
Loading