diff --git a/mdbook/src/SUMMARY.md b/mdbook/src/SUMMARY.md
index 45ad4fb99..5b48e0181 100644
--- a/mdbook/src/SUMMARY.md
+++ b/mdbook/src/SUMMARY.md
@@ -36,3 +36,4 @@
 - [Internals](./chapter_5/chapter_5.md)
     - [Communication](./chapter_5/chapter_5_1.md)
     - [Progress Tracking](./chapter_5/chapter_5_2.md)
+    - [Operator Fusion](./chapter_5/chapter_5_4.md)
diff --git a/mdbook/src/chapter_5/chapter_5_4.md b/mdbook/src/chapter_5/chapter_5_4.md
new file mode 100644
index 000000000..3d225e6ee
--- /dev/null
+++ b/mdbook/src/chapter_5/chapter_5_4.md
@@ -0,0 +1,123 @@
+# Operator fusion
+
+When building dataflows, users often compose many small operators: a `map` followed by a `filter`, a `flat_map`, another `map`, and finally a `probe`.
+Each operator is a separate node in the progress tracking graph, with its own `SharedProgress` handle, pointstamp accounting, and scheduling overhead.
+For long pipelines, this overhead dominates actual computation.
+
+Operator fusion detects groups of operators that can be scheduled as a single unit, hiding intermediate nodes from the reachability tracker.
+This section explains how fusion works and why it preserves correctness.
+
+## Which operators fuse
+
+Fusion applies to operators connected by pipeline (thread-local) channels where the group's internal progress tracking can be collapsed without losing information.
+An operator is *fusible* if:
+
+* It does not observe frontiers (`notify == false`).
+  Frontier-observing operators buffer data until they receive a notification that a timestamp is complete.
+  Fusing them would require propagating frontiers within the group, which the scheduler does not do.
+* All (input, output) pairs in its internal summary are the identity.
+  Non-identity summaries (like the feedback operator's `Product(0, 1)`) require per-member timestamp transformation that the group's aggregate reporting does not support.
+* It has an operator implementation (not already tombstoned).
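The three conditions above can be expressed as a standalone predicate. This is a simplified sketch, not timely's actual types: `Summary`, `OpInfo`, and `is_fusible` here are illustrative stand-ins for the real path-summary type, `PerOperatorState`, and the eligibility check in the fusion pass.

```rust
// Hypothetical, simplified model of an operator for illustration only.
#[derive(Clone, Copy, PartialEq)]
enum Summary { Identity, Increment }

struct OpInfo {
    notify: bool,        // does the operator observe frontiers?
    has_operator: bool,  // false once tombstoned
    // One entry per connected (input port, output port) pair.
    summaries: Vec<(usize, usize, Summary)>,
}

// An operator is fusible if it is live, never observes frontiers,
// and every connected (input, output) pair carries the identity summary.
fn is_fusible(op: &OpInfo) -> bool {
    op.has_operator
        && !op.notify
        && !op.summaries.is_empty()
        && op.summaries.iter().all(|&(_, _, s)| s == Summary::Identity)
}

fn main() {
    let map = OpInfo { notify: false, has_operator: true, summaries: vec![(0, 0, Summary::Identity)] };
    let feedback = OpInfo { notify: false, has_operator: true, summaries: vec![(0, 0, Summary::Increment)] };
    assert!(is_fusible(&map));
    assert!(!is_fusible(&feedback));
}
```

The empty-summary check mirrors the "at least one connection" requirement: an operator with no internal paths has nothing to compose.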
+
+An edge between two fusible operators is *fusible* if the target uses the pipeline pact on the corresponding input port.
+Exchange or broadcast pacts route data through inter-worker channels that the group scheduler cannot intercept.
+
+Operators connected by fusible edges are grouped using union-find.
+Groups with fewer members than a configurable threshold (`fuse_chain_length`, default 2) are left alone.
+There is no restriction on fan-in or fan-out: diamonds, concatenations, and branches all fuse.
+
+## How a fused group presents to the subgraph
+
+A fused group replaces its member operators with a single `GroupScheduler` installed at the representative slot (the lowest-indexed member).
+All other members become tombstones.
+
+The group exposes:
+
+* **Group inputs**: member input ports that receive edges from outside the group.
+* **Group outputs**: member output ports that send edges outside the group, or that have no outgoing edges (their capabilities still need tracking).
+
+The subgraph's `edge_stash` is rewritten: internal edges are removed, incoming edges are retargeted to the representative's group input ports, and outgoing edges are sourced from the representative's group output ports.
+
+## Scheduling
+
+Members are executed in topological order, computed by Kahn's algorithm over internal edges.
+This guarantees that data pushed by a producer through a pipeline channel is available to its consumer when the consumer runs.
+
+The physical pipeline channels between members are established during operator construction and are unaffected by fusion.
+Only the progress tracking layer changes.
+
+### Activation forwarding
+
+Pipeline channels activate the original target operator when data arrives.
+After fusion, the target may be a tombstone.
+Each tombstone records a `forward_to` field pointing to the group representative.
+The subgraph's scheduling loop checks this field and redirects the activation.
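The union-find grouping can be sketched in a few lines. The `find` below is the same path-compressing helper the fusion pass uses; `groups` is an illustrative wrapper (node indices and edge lists are assumptions, not timely's real structures) that unions endpoints of fusible edges and drops components below the threshold.

```rust
use std::collections::HashMap;

// Path-compressing find, as in the fusion pass itself.
fn find(parent: &mut [usize], x: usize) -> usize {
    if parent[x] != x {
        parent[x] = find(parent, parent[x]);
    }
    parent[x]
}

// Group operators 0..n connected by fusible edges; keep groups of >= min_len members.
fn groups(n: usize, fusible_edges: &[(usize, usize)], min_len: usize) -> Vec<Vec<usize>> {
    let mut parent: Vec<usize> = (0..n).collect();
    for &(a, b) in fusible_edges {
        let (ra, rb) = (find(&mut parent, a), find(&mut parent, b));
        if ra != rb { parent[rb] = ra; }
    }
    let mut components: HashMap<usize, Vec<usize>> = HashMap::new();
    for i in 0..n {
        let root = find(&mut parent, i);
        components.entry(root).or_default().push(i);
    }
    let mut out: Vec<Vec<usize>> = components.into_values()
        .filter(|g| g.len() >= min_len)
        .collect();
    out.sort();
    out
}

fn main() {
    // Operators 1, 2, 3 form a pipeline chain; 0 and 4 stay unfused.
    assert_eq!(groups(5, &[(1, 2), (2, 3)], 2), vec![vec![1, 2, 3]]);
}
```

Because union-find tracks connected components rather than chains, diamonds and fan-out fall out for free, matching the "no restriction on fan-in or fan-out" point above.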
+
+## Why the fused group reports correct progress
+
+The key insight is that because all members have identity summaries, a capability at any member's output port at timestamp `t` implies the same timestamp `t` at every reachable group output.
+The timestamp does not change along any internal path.
+
+### Consumeds and produceds
+
+The group reports consumeds only for group input ports and produceds only for group output ports.
+Intermediate consumeds and produceds (data passing between members through internal pipeline channels) would cancel in the reachability tracker: a member producing `(t, +d)` and the next member consuming `(t, -d)` net to zero.
+Since the internal edges are removed from the tracker, these intermediate changes are simply not reported.
+
+### Internal capabilities
+
+Each member reports internal capability changes through its `SharedProgress.internals`.
+In the unfused case, the reachability tracker sees each member's capabilities at their respective source locations and computes implications through the graph.
+
+The group scheduler aggregates each member's internal changes to the group outputs via a *capability map*.
+This map is computed by a single reverse-topological pass over the group's internal DAG:
+
+1. Seed: member output ports that are group outputs map directly to themselves.
+2. Reverse pass: for each member from last to first in topological order, for each output port, follow internal edges forward to downstream members.
+   Use the downstream member's summary to find which of its output ports are reachable from the targeted input port.
+   Union the reachability sets.
+
+This produces `capability_map[member][output_port] -> Vec<usize>`, a list of group output indices.
+
+When the group scheduler runs, it reads each member's `SharedProgress.internals` and reports them at every group output reached via the capability map.
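The seed-then-reverse-pass computation can be sketched on a deliberately simplified graph: one output port per member, identity summaries everywhere, members already numbered in topological order, and internal edges given as plain `(from, to)` pairs. None of these names are timely's; this is an illustration of the traversal only.

```rust
// reach[i] = group output indices reachable from member i's (single) output port.
// Assumes member indices 0..n are already a topological order.
fn capability_map(n: usize, edges: &[(usize, usize)], group_outputs: &[usize]) -> Vec<Vec<usize>> {
    let mut reach: Vec<Vec<usize>> = vec![Vec::new(); n];
    // Seed: members whose output port is itself a group output.
    for (j, &m) in group_outputs.iter().enumerate() {
        reach[m].push(j);
    }
    // Reverse-topological pass: union the reachability of downstream members.
    for i in (0..n).rev() {
        let mut acc = reach[i].clone();
        for &(from, to) in edges {
            if from == i {
                acc.extend(reach[to].iter().cloned());
            }
        }
        acc.sort();
        acc.dedup();
        reach[i] = acc;
    }
    reach
}

fn main() {
    // Diamond: 0 -> {1, 2} -> 3; member 3's output is the only group output.
    let cm = capability_map(4, &[(0, 1), (0, 2), (1, 3), (2, 3)], &[3]);
    assert_eq!(cm, vec![vec![0], vec![0], vec![0], vec![0]]);
}
```

In the diamond, both branches converge on the same group output, so the union plus dedup step is what keeps a capability from being reported twice per output.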
+Because all summaries are identity, this is equivalent to what the reachability tracker would compute by composing identity summaries along internal paths.
+
+### Initial capability accounting
+
+During `initialize()`, each member reports `+peers` capabilities at `T::minimum()` on its output ports.
+The group transfers ALL members' initial capabilities to the group's `SharedProgress`, mapped through the capability map.
+Members' initial internals are then cleared to prevent double-counting.
+
+This is necessary because each member independently drops its initial capability during execution, producing `(-peers)` changes that flow through the capability map.
+If only one member's `+peers` were reported, the tracker would go negative.
+
+## Composed summary
+
+The group's `internal_summary` describes which group outputs are reachable from which group inputs.
+For each group input, the scheduler finds which member output ports are reachable (via the member's own summary), then follows the capability map to group outputs.
+If a path exists, the summary entry is the identity; otherwise no entry exists.
+
+This composed summary is used by the reachability tracker to determine implications from the group's sources to downstream operators.
+
+## What does not fuse
+
+Several classes of operators are excluded:
+
+* **Frontier-observing operators** (`notify == true`): `inspect`, `unary_frontier`, and any operator that requests notifications.
+  These need intra-group frontier propagation, which the group scheduler does not implement.
+* **Operators with non-identity summaries**: the `Feedback` operator increments a loop counter coordinate.
+  Fusing it would require the group to transform timestamps along internal paths.
+* **Exchange-pact operators**: data moves between workers through channels outside the group scheduler's control.
+* **Operators in iteration scopes**: the nested timestamp structure typically involves non-identity summaries at scope boundaries.
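Returning to the initial-capability accounting: the reason all members' `+peers` must be transferred up front can be checked with a minimal arithmetic sketch (illustrative numbers only; `counts_after_drops` is not a real timely function).

```rust
// Count at a group output after transferring ALL members' initial +peers,
// then letting each member independently drop its capability (-peers each).
fn counts_after_drops(peers: i64, members: usize) -> Vec<i64> {
    let mut count = peers * members as i64;
    let mut trace = Vec::new();
    for _ in 0..members {
        count -= peers; // a member's -peers flows through the capability map
        trace.push(count);
    }
    trace
}

fn main() {
    let trace = counts_after_drops(4, 3);
    // The count stays non-negative and ends at exactly zero.
    assert!(trace.iter().all(|&c| c >= 0));
    assert_eq!(*trace.last().unwrap(), 0);
    // Had only one member's +peers been reported, the tracker would dip negative:
    assert!(4 - 4 * 3 < 0);
}
```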
+
+In practice, the operators that fuse are the "glue" operators: `map`, `flat_map`, `filter`, `Enter`, `Leave`, `Concatenate`, and similar single-purpose transformations.
+In differential dataflow's BFS, fusion merges groups like `[Enter, Concatenate, Negate, AsCollection, Concatenate, ResultsIn]` into single scheduling units.
+
+## Configuration
+
+Fusion is controlled by `WorkerConfig::fuse_chain_length(n)`:
+
+* `n >= 2` (default): fuse groups of at least `n` members.
+* `n == 0` or `n == 1`: disable fusion entirely.
+
+From the command line, pass `--fuse-chain-length N` to any timely program that uses `execute_from_args`.
diff --git a/timely/examples/event_driven.rs b/timely/examples/event_driven.rs
index 7a392d54e..a28f0447c 100644
--- a/timely/examples/event_driven.rs
+++ b/timely/examples/event_driven.rs
@@ -8,12 +8,25 @@ fn main() {
 
         let timer = std::time::Instant::now();
 
-        let mut args = std::env::args();
-        args.next();
+        // Collect positional arguments, skipping flags consumed by timely (-w, -n, -p, -h).
+        let positional: Vec<String> = {
+            let mut pos = Vec::new();
+            let mut args = std::env::args();
+            args.next(); // skip binary name
+            while let Some(arg) = args.next() {
+                if arg.starts_with('-') {
+                    args.next(); // skip flag value
+                } else {
+                    pos.push(arg);
+                }
+            }
+            pos
+        };
 
-        let dataflows = args.next().unwrap().parse::<usize>().unwrap();
-        let length = args.next().unwrap().parse::<usize>().unwrap();
-        let record = args.next() == Some("record".to_string());
+        let dataflows = positional[0].parse::<usize>().unwrap();
+        let length = positional[1].parse::<usize>().unwrap();
+        let record = positional.get(2).map(|s| s.as_str()) == Some("record");
+        let rounds: usize = positional.get(3).map(|s| s.parse().unwrap()).unwrap_or(usize::MAX);
 
         let mut inputs = Vec::new();
         let mut probes = Vec::new();
@@ -37,7 +50,7 @@ fn main() {
 
         println!("{:?}\tdataflows built ({} x {})", timer.elapsed(), dataflows, length);
 
-        for round in 0 .. {
+        for round in 0 .. rounds {
             let dataflow = round % dataflows;
             if record {
                 inputs[dataflow].send(());
diff --git a/timely/examples/event_driven_diamond.rs b/timely/examples/event_driven_diamond.rs
new file mode 100644
index 000000000..4045a6f17
--- /dev/null
+++ b/timely/examples/event_driven_diamond.rs
@@ -0,0 +1,67 @@
+use timely::dataflow::operators::{Input, Concat, Probe};
+use timely::dataflow::operators::vec::{Map, Filter};
+
+fn main() {
+    timely::execute_from_args(std::env::args(), |worker| {
+
+        let timer = std::time::Instant::now();
+
+        // Collect positional arguments, skipping flags consumed by timely (-w, -n, -p, -h).
+        let positional: Vec<String> = {
+            let mut pos = Vec::new();
+            let mut args = std::env::args();
+            args.next(); // skip binary name
+            while let Some(arg) = args.next() {
+                if arg.starts_with('-') {
+                    args.next(); // skip flag value
+                } else {
+                    pos.push(arg);
+                }
+            }
+            pos
+        };
+
+        let dataflows = positional[0].parse::<usize>().unwrap();
+        let diamonds = positional[1].parse::<usize>().unwrap();
+        let record = positional.get(2).map(|s| s.as_str()) == Some("record");
+        let rounds: usize = positional.get(3).map(|s| s.parse().unwrap()).unwrap_or(usize::MAX);
+
+        let mut inputs = Vec::new();
+        let mut probes = Vec::new();
+
+        // Each dataflow builds a chain of diamond patterns:
+        //   input -> map (left) + map (right) -> concat -> ... -> probe
+        // Each diamond has 3 operators (map, map, concat).
+        // The clone/branch doesn't create an operator — it reuses the stream's Tee.
+        for _dataflow in 0..dataflows {
+            worker.dataflow(|scope| {
+                let (input, mut stream) = scope.new_input();
+                for _diamond in 0..diamonds {
+                    let left = stream.clone().map(|x: ()| x);
+                    let right = stream.filter(|_| false).map(|x: ()| x);
+                    stream = left.concat(right).container::<Vec<()>>();
+                }
+                let (probe, _stream) = stream.probe();
+                inputs.push(input);
+                probes.push(probe);
+            });
+        }
+
+        println!("{:?}\tdataflows built ({} x {} diamonds)", timer.elapsed(), dataflows, diamonds);
+
+        for round in 0..rounds {
+            let dataflow = round % dataflows;
+            if record {
+                inputs[dataflow].send(());
+            }
+            inputs[dataflow].advance_to(round);
+            let mut steps = 0;
+            while probes[dataflow].less_than(&round) {
+                worker.step();
+                steps += 1;
+            }
+            println!("{:?}\tround {} complete in {} steps", timer.elapsed(), round, steps);
+        }
+
+    }).unwrap();
+}
diff --git a/timely/src/dataflow/channels/pact.rs b/timely/src/dataflow/channels/pact.rs
index 296e87a8f..5499ba956 100644
--- a/timely/src/dataflow/channels/pact.rs
+++ b/timely/src/dataflow/channels/pact.rs
@@ -25,6 +25,8 @@ pub trait ParallelizationContract<T, C> {
     type Puller: Pull<Message<T, C>>+'static;
     /// Allocates a matched pair of push and pull endpoints implementing the pact.
     fn connect<A: AsWorker>(self, allocator: &mut A, identifier: usize, address: Rc<[usize]>, logging: Option<Logger>) -> (Self::Pusher, Self::Puller);
+    /// Indicates whether this pact uses a thread-local channel (no inter-worker exchange).
+    fn is_pipeline(&self) -> bool { false }
 }
 
 /// A direct connection
@@ -34,6 +36,7 @@ pub struct Pipeline;
 impl<T: 'static, C: Container> ParallelizationContract<T, C> for Pipeline {
     type Pusher = LogPusher<T, C, ThreadPusher<Message<T, C>>>;
     type Puller = LogPuller<T, C, ThreadPuller<Message<T, C>>>;
+    fn is_pipeline(&self) -> bool { true }
     fn connect<A: AsWorker>(self, allocator: &mut A, identifier: usize, address: Rc<[usize]>, logging: Option<Logger>) -> (Self::Pusher, Self::Puller) {
         let (pusher, puller) = allocator.pipeline::<Message<T, C>>(identifier, address);
         (LogPusher::new(pusher, allocator.index(), allocator.index(), identifier, logging.clone()),
diff --git a/timely/src/dataflow/operators/generic/builder_raw.rs b/timely/src/dataflow/operators/generic/builder_raw.rs
index 7839241bc..52a228f2e 100644
--- a/timely/src/dataflow/operators/generic/builder_raw.rs
+++ b/timely/src/dataflow/operators/generic/builder_raw.rs
@@ -27,6 +27,7 @@ pub struct OperatorShape {
     peers: usize,   // The total number of workers in the computation. Needed to initialize pointstamp counts with the correct magnitude.
     inputs: usize,  // The number of input ports.
     outputs: usize, // The number of output ports.
+    pipeline: bool, // Whether all inputs use Pipeline pact (thread-local channels).
 }
 
 /// Core data for the structure of an operator, minus scope and logic.
@@ -38,6 +39,7 @@ impl OperatorShape {
             peers,
             inputs: 0,
             outputs: 0,
+            pipeline: true,
         }
     }
@@ -110,6 +112,7 @@
     {
         let channel_id = self.scope.new_identifier();
         let logging = self.scope.logging();
+        if !pact.is_pipeline() { self.shape.pipeline = false; }
         let (sender, receiver) = pact.connect(&mut self.scope, channel_id, Rc::clone(&self.address), logging);
         let target = Target::new(self.index, self.shape.inputs);
         stream.connect_to(target, sender, channel_id);
@@ -224,4 +227,5 @@ where
     }
 
     fn notify_me(&self) -> &[FrontierInterest] { &self.shape.notify }
+    fn pipeline(&self) -> bool { self.shape.pipeline }
 }
diff --git a/timely/src/dataflow/scopes/child.rs b/timely/src/dataflow/scopes/child.rs
index ee3cbb12d..50c84aeaa 100644
--- a/timely/src/dataflow/scopes/child.rs
+++ b/timely/src/dataflow/scopes/child.rs
@@ -148,7 +148,17 @@ where
             };
             func(&mut builder)
         };
-        let subscope = subscope.into_inner().build(self);
+        let mut subscope = subscope.into_inner();
+
+        // Register the fusion pass if enabled.
+        let fuse_chain_length = self.parent.config().fuse_chain_length;
+        if fuse_chain_length >= 2 {
+            subscope.add_graph_pass(Box::new(
+                crate::progress::fusion::FusionPass::new(fuse_chain_length)
+            ));
+        }
+
+        let subscope = subscope.build(self);
 
         self.add_operator_with_indices(Box::new(subscope), index, identifier);
diff --git a/timely/src/progress/fusion.rs b/timely/src/progress/fusion.rs
new file mode 100644
index 000000000..5f2ecffff
--- /dev/null
+++ b/timely/src/progress/fusion.rs
@@ -0,0 +1,640 @@
+//! Pipeline group fusion: detects and fuses groups of pipeline-connected operators.
+//!
+//! This module implements operator fusion as a `GraphPass`. Groups of operators
+//! connected by pipeline (thread-local) channels with identity summaries are
+//! fused into a single `GroupScheduler` operator. The group is scheduled as a
+//! unit in topological order, hiding intermediate progress from the reachability
+//! tracker.
+
+use std::rc::Rc;
+use std::cell::RefCell;
+use std::collections::HashMap;
+
+use crate::scheduling::Schedule;
+use crate::progress::{Source, Target, Timestamp};
+use crate::progress::operate::SharedProgress;
+use crate::progress::operate::{FrontierInterest, Connectivity, PortConnectivity};
+
+use super::subgraph::PerOperatorState;
+use super::graph_pass::GraphPass;
+
+/// A graph pass that fuses groups of pipeline-connected operators.
+///
+/// Operators are eligible for fusion when they have identity internal summaries,
+/// do not require notifications (`notify = false`), and are connected via
+/// pipeline (thread-local) channels. Groups of at least `min_length` eligible
+/// operators are detected using union-find and fused into a single operator.
+pub struct FusionPass {
+    /// Minimum group size for fusion.
+    min_length: usize,
+}
+
+impl FusionPass {
+    /// Creates a new fusion pass with the given minimum group size.
+    pub fn new(min_length: usize) -> Self {
+        FusionPass { min_length }
+    }
+}
+
+impl<T: Timestamp> GraphPass<T> for FusionPass {
+    fn apply(&self, children: &mut Vec<PerOperatorState<T>>, edges: &mut Vec<(Source, Target)>) {
+        let groups = detect_groups(children, edges, self.min_length);
+        for group in groups {
+            fuse_group::<T>(children, edges, &group);
+        }
+    }
+}
+
+/// Returns true if an operator has identity internal summaries on all (input, output) pairs.
+/// That is, every connected (input, output) pair has summary `Antichain::from_elem(Default::default())`.
+fn has_identity_summary<T: Timestamp>(child: &PerOperatorState<T>) -> bool {
+    for input_pc in child.internal_summary.iter() {
+        for (_port, ac) in input_pc.iter_ports() {
+            if ac.len() != 1 || ac.elements()[0] != Default::default() {
+                return false;
+            }
+        }
+    }
+    // Must have at least one connection (empty summary means no paths).
+    child.internal_summary.iter().any(|pc| pc.iter_ports().next().is_some())
+}
+
+/// Returns true if an operator is eligible for group fusion.
+fn is_fusible<T: Timestamp>(child: &PerOperatorState<T>) -> bool {
+    child.operator.is_some()
+        && child.notify.iter().all(|n| *n == FrontierInterest::Never)
+        && has_identity_summary(child)
+}
+
+/// Detects fusible groups of operators connected by pipeline edges.
+///
+/// Uses union-find to group operators connected by fusible edges into components.
+/// An edge is fusible when both endpoints are fusible operators and the target
+/// uses pipeline (thread-local) channels. No fan-in/fan-out or 1-in/1-out restriction.
+///
+/// Returns groups of at least `min_length` operators, identified by child index.
+fn detect_groups<T: Timestamp>(
+    children: &[PerOperatorState<T>],
+    edge_stash: &[(Source, Target)],
+    min_length: usize,
+) -> Vec<Vec<usize>> {
+    // Mark fusible operators.
+    let fusible: Vec<bool> = children.iter().enumerate().map(|(i, child)| {
+        i != 0 && is_fusible(child)
+    }).collect();
+
+    // Union-Find structure.
+    let n = children.len();
+    let mut parent: Vec<usize> = (0..n).collect();
+    let mut rank: Vec<usize> = vec![0; n];
+
+    fn find(parent: &mut [usize], x: usize) -> usize {
+        if parent[x] != x {
+            parent[x] = find(parent, parent[x]);
+        }
+        parent[x]
+    }
+
+    fn union(parent: &mut [usize], rank: &mut [usize], a: usize, b: usize) {
+        let ra = find(parent, a);
+        let rb = find(parent, b);
+        if ra == rb { return; }
+        if rank[ra] < rank[rb] {
+            parent[ra] = rb;
+        } else if rank[ra] > rank[rb] {
+            parent[rb] = ra;
+        } else {
+            parent[rb] = ra;
+            rank[ra] += 1;
+        }
+    }
+
+    // For each edge, if both endpoints are fusible and target uses pipeline pact, union them.
+    for (source, target) in edge_stash.iter() {
+        let src = source.node;
+        let tgt = target.node;
+        if src == 0 || tgt == 0 { continue; }
+        if !fusible[src] || !fusible[tgt] { continue; }
+        if !children[tgt].pipeline { continue; }
+        union(&mut parent, &mut rank, src, tgt);
+    }
+
+    // Collect components.
+    let mut components: HashMap<usize, Vec<usize>> = HashMap::new();
+    for i in 1..n {
+        if fusible[i] {
+            let root = find(&mut parent, i);
+            components.entry(root).or_default().push(i);
+        }
+    }
+
+    // Filter by minimum size and sort members for determinism.
+    components.into_values()
+        .filter(|group| group.len() >= min_length)
+        .map(|mut group| { group.sort(); group })
+        .collect()
+}
+
+/// Topological sort of group members using Kahn's algorithm on internal edges.
+fn topological_sort(
+    members: &[usize],
+    edge_stash: &[(Source, Target)],
+) -> Vec<usize> {
+    let member_set: std::collections::HashSet<usize> = members.iter().cloned().collect();
+    let member_to_pos: HashMap<usize, usize> = members.iter().enumerate().map(|(i, &m)| (m, i)).collect();
+    let n = members.len();
+
+    let mut in_degree = vec![0usize; n];
+    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
+
+    for (source, target) in edge_stash.iter() {
+        if member_set.contains(&source.node) && member_set.contains(&target.node) {
+            let from = member_to_pos[&source.node];
+            let to = member_to_pos[&target.node];
+            // Avoid counting duplicate edges for the same (from, to) pair multiple times
+            // for in-degree. We track adjacency; Kahn's handles it correctly.
+            adj[from].push(to);
+            in_degree[to] += 1;
+        }
+    }
+
+    let mut queue: std::collections::VecDeque<usize> = std::collections::VecDeque::new();
+    for i in 0..n {
+        if in_degree[i] == 0 {
+            queue.push_back(i);
+        }
+    }
+
+    let mut order = Vec::with_capacity(n);
+    while let Some(pos) = queue.pop_front() {
+        order.push(members[pos]);
+        for &next in &adj[pos] {
+            in_degree[next] -= 1;
+            if in_degree[next] == 0 {
+                queue.push_back(next);
+            }
+        }
+    }
+
+    assert_eq!(order.len(), n, "group contains a cycle, which should be impossible with identity summaries");
+    order
+}
+
+/// A member of a fused group, holding the original operator and its progress handle.
+struct GroupMember<T: Timestamp> {
+    operator: Box<dyn Schedule>,
+    shared_progress: Rc<RefCell<SharedProgress<T>>>,
+}
+
+/// Schedules a DAG of pipeline-connected operators as a single unit.
+///
+/// The group presents as a single operator to the subgraph with `input_map.len()` inputs
+/// and `output_map.len()` outputs. Members are scheduled in topological order.
+/// Intermediate progress is hidden from the reachability tracker.
+struct GroupScheduler<T: Timestamp> {
+    name: String,
+    path: Vec<usize>,
+    /// Progress visible to the subgraph.
+    group_progress: Rc<RefCell<SharedProgress<T>>>,
+    /// Operators in topological order, with their individual SharedProgress handles.
+    members: Vec<GroupMember<T>>,
+    /// Group input i -> (member index in members vec, member input port)
+    input_map: Vec<(usize, usize)>,
+    /// Group output j -> (member index in members vec, member output port)
+    output_map: Vec<(usize, usize)>,
+    /// capability_map[member_idx][output_port] -> list of group output indices
+    capability_map: Vec<Vec<Vec<usize>>>,
+}
+
+impl<T: Timestamp> Schedule for GroupScheduler<T> {
+    fn name(&self) -> &str { &self.name }
+    fn path(&self) -> &[usize] { &self.path }
+
+    fn schedule(&mut self) -> bool {
+        let n = self.members.len();
+        assert!(n > 0);
+
+        // Step 1: Forward group's input frontier changes to the appropriate members.
+        {
+            let mut group_sp = self.group_progress.borrow_mut();
+            for (i, &(member_idx, member_port)) in self.input_map.iter().enumerate() {
+                let mut member_sp = self.members[member_idx].shared_progress.borrow_mut();
+                for (time, diff) in group_sp.frontiers[i].iter() {
+                    member_sp.frontiers[member_port].update(time.clone(), *diff);
+                }
+            }
+        }
+
+        // Step 2: Schedule each member in topological order.
+        let mut any_incomplete = false;
+        for i in 0..n {
+            let incomplete = self.members[i].operator.schedule();
+            any_incomplete = any_incomplete || incomplete;
+        }
+
+        // Step 3: Aggregate progress.
+        {
+            let mut group_sp = self.group_progress.borrow_mut();
+
+            // consumeds: for each group input, take from the corresponding member.
+            for (i, &(member_idx, member_port)) in self.input_map.iter().enumerate() {
+                let mut member_sp = self.members[member_idx].shared_progress.borrow_mut();
+                for (time, diff) in member_sp.consumeds[member_port].iter() {
+                    group_sp.consumeds[i].update(time.clone(), *diff);
+                }
+            }
+
+            // produceds: for each group output, take from the corresponding member.
+            for (j, &(member_idx, member_port)) in self.output_map.iter().enumerate() {
+                let mut member_sp = self.members[member_idx].shared_progress.borrow_mut();
+                for (time, diff) in member_sp.produceds[member_port].iter() {
+                    group_sp.produceds[j].update(time.clone(), *diff);
+                }
+            }
+
+            // internals: for each member's output port, report at mapped group outputs.
+            for (m, member) in self.members.iter().enumerate() {
+                let mut member_sp = member.shared_progress.borrow_mut();
+                for (port, internal) in member_sp.internals.iter_mut().enumerate() {
+                    for (time, diff) in internal.iter() {
+                        for &group_out in &self.capability_map[m][port] {
+                            group_sp.internals[group_out].update(time.clone(), *diff);
+                        }
+                    }
+                }
+            }
+        }
+
+        // Step 4: Clear all members' SharedProgress to prevent accumulation.
+        for member in self.members.iter() {
+            let mut sp = member.shared_progress.borrow_mut();
+            for batch in sp.frontiers.iter_mut() { batch.clear(); }
+            for batch in sp.consumeds.iter_mut() { batch.clear(); }
+            for batch in sp.internals.iter_mut() { batch.clear(); }
+            for batch in sp.produceds.iter_mut() { batch.clear(); }
+        }
+
+        // Clear the group's own frontiers (consumed by step 1).
+        {
+            let mut group_sp = self.group_progress.borrow_mut();
+            for batch in group_sp.frontiers.iter_mut() { batch.clear(); }
+        }
+
+        any_incomplete
+    }
+}
+
+/// Computes reachability for all (member, output_port) pairs in a single reverse-topological pass.
+///
+/// Returns `capability_map[topo_pos][output_port] -> sorted Vec<usize>`.
+/// Since all summaries are identity, timestamps don't change along any path.
+fn compute_all_reachability<T: Timestamp>(
+    topo_order: &[usize],
+    children: &[PerOperatorState<T>],
+    internal_edges: &HashMap<(usize, usize), Vec<(usize, usize)>>,
+    member_summaries: &HashMap<usize, Vec<(usize, usize)>>,
+    output_port_to_group_output: &HashMap<(usize, usize), Vec<usize>>,
+) -> Vec<Vec<Vec<usize>>> {
+    let n = topo_order.len();
+
+    // Build reverse lookup: node -> topo_pos.
+    let node_to_topo: HashMap<usize, usize> = topo_order.iter().enumerate()
+        .map(|(i, &node)| (node, i))
+        .collect();
+
+    // reachable[(node, output_port)] -> set of group output indices
+    // Use a flat Vec indexed by (topo_pos, port) for fast access.
+    // First, compute a port offset table.
+    let mut port_offset = Vec::with_capacity(n);
+    let mut total_ports = 0usize;
+    for &node in topo_order.iter() {
+        port_offset.push(total_ports);
+        total_ports += children[node].outputs;
+    }
+
+    // Each entry is a sorted Vec of reachable group outputs.
+    let mut reachable: Vec<Vec<usize>> = vec![Vec::new(); total_ports];
+
+    // Seed: output ports that are directly group outputs.
+    for (&(node, port), group_outs) in output_port_to_group_output.iter() {
+        if let Some(&topo_pos) = node_to_topo.get(&node) {
+            let idx = port_offset[topo_pos] + port;
+            reachable[idx] = group_outs.clone();
+            reachable[idx].sort();
+            reachable[idx].dedup();
+        }
+    }
+
+    // Reverse topological pass: propagate reachability backward through edges.
+    for rev_pos in (0..n).rev() {
+        let node = topo_order[rev_pos];
+        let num_outputs = children[node].outputs;
+
+        // For each output port of this node, follow internal edges forward
+        // and union the reachability of the downstream ports.
+        for port in 0..num_outputs {
+            if let Some(targets) = internal_edges.get(&(node, port)) {
+                for &(next_node, next_input_port) in targets {
+                    // Use next node's summary to find which output ports are reachable from this input.
+                    if let Some(connections) = member_summaries.get(&next_node) {
+                        if let Some(&next_topo) = node_to_topo.get(&next_node) {
+                            for &(inp, outp) in connections.iter() {
+                                if inp == next_input_port {
+                                    // Merge reachable[next_topo][outp] into reachable[rev_pos][port].
+                                    let src_idx = port_offset[next_topo] + outp;
+                                    let dst_idx = port_offset[rev_pos] + port;
+                                    if src_idx != dst_idx {
+                                        // Clone to avoid double borrow.
+                                        let to_add = reachable[src_idx].clone();
+                                        let dst = &mut reachable[dst_idx];
+                                        dst.extend_from_slice(&to_add);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            // Deduplicate after merging all edges for this port.
+            let idx = port_offset[rev_pos] + port;
+            reachable[idx].sort();
+            reachable[idx].dedup();
+        }
+    }
+
+    // Reshape into capability_map[topo_pos][output_port].
+    let mut capability_map = Vec::with_capacity(n);
+    for (topo_pos, &node) in topo_order.iter().enumerate() {
+        let num_outputs = children[node].outputs;
+        let mut port_map = Vec::with_capacity(num_outputs);
+        for port in 0..num_outputs {
+            let idx = port_offset[topo_pos] + port;
+            port_map.push(std::mem::take(&mut reachable[idx]));
+        }
+        capability_map.push(port_map);
+    }
+
+    capability_map
+}
+
+/// Fuses a detected group into a single operator within `children`, rewriting `edge_stash`.
+///
+/// The representative (lowest index in group) retains its slot and becomes the fused operator.
+/// All other group members become tombstones: their operator is removed, inputs/outputs set
+/// to zero, and `forward_to` is set to the representative for activation forwarding.
+fn fuse_group<T: Timestamp>(
+    children: &mut [PerOperatorState<T>],
+    edge_stash: &mut Vec<(Source, Target)>,
+    group: &[usize],
+) {
+    assert!(group.len() >= 2);
+    let group_set: std::collections::HashSet<usize> = group.iter().cloned().collect();
+    let representative = *group.iter().min().unwrap();
+
+    // Step 1: Compute topological order.
+    let topo_order = topological_sort(group, edge_stash);
+    let node_to_topo: HashMap<usize, usize> = topo_order.iter().enumerate().map(|(i, &n)| (n, i)).collect();
+
+    // Step 2: Compute input_map and output_map by scanning edges.
+    // Group inputs: (member_node, input_port) pairs where at least one incoming edge originates outside the group.
+    // Group outputs: (member_node, output_port) pairs where at least one outgoing edge targets outside the group,
+    // OR the port has no outgoing edges at all within the edge_stash.
+    let mut group_input_set: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
+    let mut group_output_set: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
+    let mut has_outgoing: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
+
+    // Collect all output ports of group members.
+    let mut all_output_ports: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
+    for &node in group.iter() {
+        for port in 0..children[node].outputs {
+            all_output_ports.insert((node, port));
+        }
+    }
+
+    // Build internal edges map: (src_node, src_port) -> [(tgt_node, tgt_port)]
+    let mut internal_edges: HashMap<(usize, usize), Vec<(usize, usize)>> = HashMap::new();
+
+    for (source, target) in edge_stash.iter() {
+        let src_in = group_set.contains(&source.node);
+        let tgt_in = group_set.contains(&target.node);
+
+        if src_in {
+            has_outgoing.insert((source.node, source.port));
+        }
+
+        if src_in && tgt_in {
+            // Internal edge
+            internal_edges.entry((source.node, source.port))
+                .or_default()
+                .push((target.node, target.port));
+        } else if !src_in && tgt_in {
+            // Incoming edge from outside
+            group_input_set.insert((target.node, target.port));
+        } else if src_in && !tgt_in {
+            // Outgoing edge to outside
+            group_output_set.insert((source.node, source.port));
+        }
+    }
+
+    // Output ports with no outgoing edges at all are also group outputs.
+    for &(node, port) in &all_output_ports {
+        if !has_outgoing.contains(&(node, port)) {
+            group_output_set.insert((node, port));
+        }
+    }
+
+    // Sort and assign indices for determinism.
+    let mut input_map: Vec<(usize, usize)> = group_input_set.into_iter()
+        .map(|(node, port)| (node_to_topo[&node], port))
+        .collect();
+    input_map.sort();
+    // Convert back: input_map elements are (topo_position, port)
+
+    let mut output_map: Vec<(usize, usize)> = group_output_set.into_iter()
+        .map(|(node, port)| (node_to_topo[&node], port))
+        .collect();
+    output_map.sort();
+
+    // Build reverse lookups.
+    // (node, input_port) -> group input index
+    let input_port_to_group_input: HashMap<(usize, usize), usize> = input_map.iter().enumerate()
+        .map(|(i, &(topo_pos, port))| ((topo_order[topo_pos], port), i))
+        .collect();
+
+    // (node, output_port) -> group output indices
+    let mut output_port_to_group_output: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
+    for (j, &(topo_pos, port)) in output_map.iter().enumerate() {
+        output_port_to_group_output.entry((topo_order[topo_pos], port))
+            .or_default()
+            .push(j);
+    }
+
+    // Step 3: Compute member summaries (input_port, output_port) connections for each node.
+    let mut member_summaries: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
+    for &node in group.iter() {
+        let mut connections = Vec::new();
+        for (inp_idx, pc) in children[node].internal_summary.iter().enumerate() {
+            for (out_port, _ac) in pc.iter_ports() {
+                connections.push((inp_idx, out_port));
+            }
+        }
+        member_summaries.insert(node, connections);
+    }
+
+    // Step 4: Compute capability_map via reachability (single reverse-topological pass).
+    // capability_map[topo_pos][output_port] -> list of group output indices
+    let capability_map = compute_all_reachability(
+        &topo_order, children, &internal_edges, &member_summaries, &output_port_to_group_output,
+    );
+
+    // Step 5: Compute composed summary for the group.
+    // For each (group_input_i, group_output_j): if there's a reachability path, set identity summary.
+    let num_inputs = input_map.len();
+    let num_outputs = output_map.len();
+
+    let mut composed_summary: Connectivity<T::Summary> = Vec::with_capacity(num_inputs);
+    for &(topo_pos, port) in input_map.iter() {
+        let node = topo_order[topo_pos];
+        let mut pc = PortConnectivity::default();
+
+        // Find which group outputs are reachable from this input.
+        // Use the node's summary to find output ports reachable from this input port,
+        // then use capability_map for those output ports.
+        if let Some(connections) = member_summaries.get(&node) {
+            for &(inp, outp) in connections.iter() {
+                if inp == port {
+                    for &group_out in &capability_map[topo_pos][outp] {
+                        pc.insert(group_out, Default::default());
+                    }
+                }
+            }
+        }
+        composed_summary.push(pc);
+    }
+
+    // Step 6: Extract members in topological order.
+    let mut members = Vec::with_capacity(topo_order.len());
+    let mut group_name_parts = Vec::new();
+    let mut representative_path = Vec::new();
+
+    for &node in topo_order.iter() {
+        let child = &mut children[node];
+        let operator = child.operator.take().expect("group member must have an operator");
+        let shared_progress = Rc::clone(&child.shared_progress);
+
+        if node == representative {
+            representative_path = operator.path().to_vec();
+        }
+        group_name_parts.push(child.name.clone());
+
+        members.push(GroupMember {
+            operator,
+            shared_progress,
+        });
+    }
+
+    let group_name = format!("Group[{}]", group_name_parts.join(", "));
+
+    // Step 7: Create the group's SharedProgress.
+    let group_progress = Rc::new(RefCell::new(SharedProgress::new(num_inputs, num_outputs)));
+
+    // Transfer initial internal capabilities from ALL members to group_progress,
+    // mapped through capability_map.
+    {
+        let mut group_sp = group_progress.borrow_mut();
+        for (topo_pos, member) in members.iter().enumerate() {
+            let mut member_sp = member.shared_progress.borrow_mut();
+            for (port, internal) in member_sp.internals.iter_mut().enumerate() {
+                for (time, diff) in internal.iter() {
+                    for &group_out in &capability_map[topo_pos][port] {
+                        group_sp.internals[group_out].update(time.clone(), *diff);
+                    }
+                }
+            }
+        }
+    }
+
+    // Clear all members' internals to prevent double-counting during initialize().
+    for member in members.iter() {
+        let mut sp = member.shared_progress.borrow_mut();
+        for batch in sp.internals.iter_mut() { batch.clear(); }
+    }
+
+    let group_scheduler: Box<dyn Operate<T>> = Box::new(GroupScheduler {
+        name: group_name.clone(),
+        path: representative_path,
+        group_progress: Rc::clone(&group_progress),
+        members,
+        input_map: input_map.clone(),
+        output_map: output_map.clone(),
+        capability_map,
+    });
+
+    // Step 8: Install the fused operator at the representative slot.
+    // Edges are left empty here; the build method populates them from edge_stash.
+    let head = &mut children[representative];
+    head.name = group_name;
+    head.operator = Some(group_scheduler);
+    head.shared_progress = group_progress;
+    head.internal_summary = composed_summary;
+    head.notify = vec![FrontierInterest::Never; num_inputs];
+    head.inputs = num_inputs;
+    head.outputs = num_outputs;
+    head.edges = vec![Vec::new(); num_outputs];
+
+    // Step 9: Tombstone all other group members, forwarding activations to representative.
+    for &node in group.iter() {
+        if node == representative { continue; }
+        let child = &mut children[node];
+        child.name = format!("Tombstone({})", child.name);
+        child.operator = None;
+        child.shared_progress = Rc::new(RefCell::new(SharedProgress::new(0, 0)));
+        child.edges = Vec::new();
+        child.inputs = 0;
+        child.outputs = 0;
+        child.internal_summary = Vec::new();
+        child.forward_to = Some(representative);
+    }
+
+    // Step 10: Rewrite edge_stash.
+    // Remove edges where both endpoints are in the group.
+    // Rewrite edges incoming to group members: target.node = representative, target.port = group_input_index.
+    // Rewrite edges outgoing from group members: source.node = representative, source.port = group_output_index.
+    let mut new_edge_stash: Vec<(Source, Target)> = Vec::new();
+
+    for (source, target) in edge_stash.iter() {
+        let src_in = group_set.contains(&source.node);
+        let tgt_in = group_set.contains(&target.node);
+
+        if src_in && tgt_in {
+            // Internal edge: remove.
+            continue;
+        } else if !src_in && tgt_in {
+            // Incoming edge: rewrite target.
+            if let Some(&group_input) = input_port_to_group_input.get(&(target.node, target.port)) {
+                new_edge_stash.push((
+                    *source,
+                    Target::new(representative, group_input),
+                ));
+            }
+        } else if src_in && !tgt_in {
+            // Outgoing edge: rewrite source.
+            let topo_pos = node_to_topo[&source.node];
+            if let Some(group_outs) = output_port_to_group_output.get(&(source.node, source.port)) {
+                for &group_out in group_outs {
+                    if output_map[group_out] == (topo_pos, source.port) {
+                        new_edge_stash.push((
+                            Source::new(representative, group_out),
+                            *target,
+                        ));
+                    }
+                }
+            }
+        } else {
+            // Neither endpoint in group: keep as-is.
+            new_edge_stash.push((*source, *target));
+        }
+    }
+
+    *edge_stash = new_edge_stash;
+}
diff --git a/timely/src/progress/graph_pass.rs b/timely/src/progress/graph_pass.rs
new file mode 100644
index 000000000..f38a1b28f
--- /dev/null
+++ b/timely/src/progress/graph_pass.rs
@@ -0,0 +1,35 @@
+//! Graph transformation pass API for subgraph optimization.
+//!
+//! A `GraphPass` receives the graph topology (children and edges) and can transform
+//! it before the reachability tracker is built. This allows optimizations like
+//! operator fusion to be implemented as pluggable passes, decoupled from the
+//! progress tracking code in `subgraph.rs`.
+
+use crate::progress::{Source, Target, Timestamp};
+use super::subgraph::PerOperatorState;
+
+/// A graph transformation pass that runs during `SubgraphBuilder::build()`.
+///
+/// Passes receive the children vector and edge list, and may transform them
+/// in place. A pass that merges operators should tombstone the absorbed
+/// operators (setting their inputs/outputs to 0, clearing their operator,
+/// and optionally setting `forward_to` for activation forwarding).
+///
+/// Passes run sequentially in registration order, each seeing the output
+/// of the previous pass.
+pub(crate) trait GraphPass<T: Timestamp> {
+    /// Transform the graph topology in place.
+    ///
+    /// The `children` vector is indexed by operator index (child 0 is the
+    /// subgraph boundary). The `edges` vector contains all (source, target)
+    /// pairs representing dataflow connections.
+    ///
+    /// Implementations may:
+    /// * Modify operator state (replace operators, change port counts)
+    /// * Add or remove edges
+    /// * Tombstone operators by clearing their fields and setting `forward_to`
+    ///
+    /// Implementations must preserve the length of `children` (indices are
+    /// used by the reachability tracker).
+    fn apply(&self, children: &mut Vec<PerOperatorState<T>>, edges: &mut Vec<(Source, Target)>);
+}
diff --git a/timely/src/progress/mod.rs b/timely/src/progress/mod.rs
index 1ff95a977..09d2081c7 100644
--- a/timely/src/progress/mod.rs
+++ b/timely/src/progress/mod.rs
@@ -15,6 +15,8 @@
 pub mod operate;
 pub mod broadcast;
 pub mod reachability;
 pub mod subgraph;
+pub(crate) mod graph_pass;
+pub(crate) mod fusion;

 /// A timely dataflow location.
 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize, Columnar)]
diff --git a/timely/src/progress/operate.rs b/timely/src/progress/operate.rs
index 6c3f12955..92dbf8017 100644
--- a/timely/src/progress/operate.rs
+++ b/timely/src/progress/operate.rs
@@ -61,6 +61,14 @@ pub trait Operate<T: Timestamp> {
     /// frontier changes on that input should cause the operator to be scheduled. The conservative
     /// default is `Always` for each input.
     fn notify_me(&self) -> &[FrontierInterest];// { &vec![FrontierInterest::Always; self.inputs()] }
+
+    /// Indicates whether all inputs use thread-local (pipeline) channels.
+    ///
+    /// Operators with pipeline inputs receive data through thread-local channels,
+    /// meaning data pushed by an upstream operator on the same worker is immediately
+    /// available. This property is used by chain fusion to determine whether
+    /// consecutive operators can be scheduled as a single unit.
+    fn pipeline(&self) -> bool { true }
 }

 /// The ways in which an operator can express interest in activation when an input frontier changes.
diff --git a/timely/src/progress/subgraph.rs b/timely/src/progress/subgraph.rs
index e11011368..12336753b 100644
--- a/timely/src/progress/subgraph.rs
+++ b/timely/src/progress/subgraph.rs
@@ -26,6 +26,7 @@
 use crate::progress::reachability;
 use crate::progress::timestamp::Refines;
 use crate::worker::ProgressMode;
+use crate::progress::graph_pass::GraphPass;

 // IMPORTANT : by convention, a child identifier of zero is used to indicate inputs and outputs of
 // the Subgraph itself. An identifier greater than zero corresponds to an actual child, which can
@@ -69,6 +70,9 @@ where
     logging: Option<Logger>,
     /// Typed logging handle for operator summaries.
     summary_logging: Option>,
+
+    /// Graph transformation passes to run during `build()`.
+    graph_passes: Vec<Box<dyn GraphPass<TInner>>>,
 }

 impl<TOuter, TInner> SubgraphBuilder<TOuter, TInner>
@@ -123,6 +127,7 @@ where
             output_capabilities: Vec::new(),
             logging,
             summary_logging,
+            graph_passes: Vec::new(),
         }
     }
@@ -132,6 +137,15 @@ where
         self.child_count - 1
     }

+    /// Registers a graph transformation pass to run during `build()`.
+    ///
+    /// Passes run sequentially in registration order, each seeing the output
+    /// of the previous pass. They execute after operators are initialized but
+    /// before the reachability tracker is built.
+    pub(crate) fn add_graph_pass(&mut self, pass: Box<dyn GraphPass<TInner>>) {
+        self.graph_passes.push(pass);
+    }
+
     /// Adds a new child to the subgraph.
     pub fn add_child(&mut self, child: Box<dyn Operate<TInner>>, index: usize, identifier: usize) {
         let child = PerOperatorState::new(child, index, identifier, self.logging.clone(), &mut self.summary_logging);
@@ -166,12 +180,19 @@ where
         // Create empty child zero representative.
         self.children[0] = PerOperatorState::empty(outputs, inputs);

+        // Run registered graph transformation passes (e.g., operator fusion).
+        for pass in self.graph_passes.iter() {
+            pass.apply(&mut self.children, &mut self.edge_stash);
+        }
+
         let mut builder = reachability::Builder::new();

         // Child 0 has `inputs` outputs and `outputs` inputs, not yet connected.
         let summary = (0..outputs).map(|_| PortConnectivity::default()).collect();
         builder.add_node(0, outputs, inputs, summary);
         for (index, child) in self.children.iter().enumerate().skip(1) {
+            // Tombstoned children are added with (0, 0) inputs/outputs and empty summary
+            // to preserve index positions in the reachability tracker.
             builder.add_node(index, child.inputs, child.outputs, child.internal_summary.clone());
         }

@@ -192,7 +213,13 @@ where
         let mut incomplete = vec![true; self.children.len()];
         incomplete[0] = false;
-        let incomplete_count = incomplete.len() - 1;
+        // Tombstoned children are not incomplete.
+        for (i, child) in self.children.iter().enumerate().skip(1) {
+            if child.inputs == 0 && child.outputs == 0 && child.operator.is_none() {
+                incomplete[i] = false;
+            }
+        }
+        let incomplete_count = incomplete.iter().filter(|&&b| b).count();

         let activations = worker.activations();
@@ -321,13 +348,16 @@ where
         //
         // We should be able to schedule arbitrary subsets of children, as
         // long as we eventually schedule all children that need to do work.
-        let mut previous = 0;
+        let mut scheduled = std::collections::HashSet::new();
+        scheduled.insert(0); // Child 0 is the subgraph boundary, never scheduled.
         while let Some(Reverse(index)) = self.temp_active.pop() {
-            // De-duplicate, and don't revisit.
-            if index > previous {
+            if !scheduled.insert(index) {
                 continue;
             }
+            // Tombstoned group members forward activations to their representative.
+            if let Some(fwd) = self.children[index].forward_to {
+                self.temp_active.push(Reverse(fwd));
+            } else {
                 // TODO: This is a moment where a scheduling decision happens.
                 self.activate_child(index);
-                previous = index;
             }
         }
@@ -606,26 +636,43 @@ where
     fn notify_me(&self) -> &[FrontierInterest] { &self.notify_me }
 }

-struct PerOperatorState<T: Timestamp> {
+/// Per-operator state within a subgraph.
+///
+/// Each child operator in a subgraph has an associated `PerOperatorState` that
+/// tracks its scheduling state, progress information, and graph connectivity.
+///
+/// Graph passes may modify these fields to implement transformations like
+/// operator fusion. In particular, a pass that merges operators should
+/// tombstone absorbed operators by clearing their fields and setting
+/// `forward_to` for activation forwarding.
+pub(crate) struct PerOperatorState<T: Timestamp> {

-    name: String,       // name of the operator
-    index: usize,       // index of the operator within its parent scope
-    id: usize,          // worker-unique identifier
+    pub(crate) name: String,       // name of the operator
+    pub(crate) index: usize,       // index of the operator within its parent scope
+    pub(crate) id: usize,          // worker-unique identifier

-    local: bool,        // indicates whether the operator will exchange data or not
-    notify: Vec<FrontierInterest>,
-    inputs: usize,      // number of inputs to the operator
-    outputs: usize,     // number of outputs from the operator
+    pub(crate) local: bool,        // indicates whether progress information is pre-circulated or not
+    pub(crate) notify: Vec<FrontierInterest>,
+    pub(crate) pipeline: bool,     // indicates whether all inputs use thread-local (pipeline) channels
+    pub(crate) inputs: usize,      // number of inputs to the operator
+    pub(crate) outputs: usize,     // number of outputs from the operator

-    operator: Option<Box<dyn Operate<T>>>,
+    pub(crate) operator: Option<Box<dyn Operate<T>>>,

-    edges: Vec<Vec<Target>>,    // edges from the outputs of the operator
+    pub(crate) edges: Vec<Vec<Target>>,    // edges from the outputs of the operator

-    shared_progress: Rc<RefCell<SharedProgress<T>>>,
+    pub(crate) shared_progress: Rc<RefCell<SharedProgress<T>>>,

-    internal_summary: Connectivity<T::Summary>,   // cached result from initialize.
+    pub(crate) internal_summary: Connectivity<T::Summary>,   // cached result from initialize.

-    logging: Option<Logger>,
+    pub(crate) logging: Option<Logger>,
+
+    /// For tombstoned operators: forward activations to this operator index instead.
+    ///
+    /// When a graph pass merges multiple operators into one, the absorbed operators
+    /// are tombstoned and their activations are forwarded to the representative
+    /// operator that replaced them.
+    pub(crate) forward_to: Option<usize>,
 }

 impl<T: Timestamp> PerOperatorState<T> {
@@ -638,6 +685,7 @@ impl<T: Timestamp> PerOperatorState<T> {
             id: usize::MAX,
             local: false,
             notify: vec![FrontierInterest::IfCapability; inputs],
+            pipeline: false,
             inputs,
             outputs,
@@ -647,6 +695,7 @@ impl<T: Timestamp> PerOperatorState<T> {
             shared_progress: Rc::new(RefCell::new(SharedProgress::new(inputs,outputs))),
             internal_summary: Vec::new(),
+            forward_to: None,
         }
     }
@@ -662,6 +711,7 @@ impl<T: Timestamp> PerOperatorState<T> {
         let inputs = scope.inputs();
         let outputs = scope.outputs();
         let notify = scope.notify_me().to_vec();
+        let pipeline = scope.pipeline();

         let (internal_summary, shared_progress, operator) = scope.initialize();
@@ -691,6 +741,7 @@ impl<T: Timestamp> PerOperatorState<T> {
             id: identifier,
             local,
             notify,
+            pipeline,
             inputs,
             outputs,
             edges: vec![vec![]; outputs],
@@ -699,6 +750,7 @@ impl<T: Timestamp> PerOperatorState<T> {
             shared_progress,
             internal_summary,
+            forward_to: None,
         }
     }
@@ -823,3 +875,4 @@ impl<T: Timestamp> Drop for PerOperatorState<T> {
         self.shut_down();
     }
 }
+
diff --git a/timely/src/worker.rs b/timely/src/worker.rs
index 70051214a..e95886946 100644
--- a/timely/src/worker.rs
+++ b/timely/src/worker.rs
@@ -78,14 +78,29 @@ impl FromStr for ProgressMode {
 }

 /// Worker configuration.
-#[derive(Debug, Default, Clone)]
+#[derive(Debug, Clone)]
 pub struct Config {
     /// The progress mode to use.
     pub(crate) progress_mode: ProgressMode,
+    /// Minimum chain length for pipeline chain fusion (default: 2).
+    ///
+    /// Chains of pipeline-connected operators shorter than this threshold
+    /// will not be fused. Set to 0 to disable fusion entirely.
+    pub(crate) fuse_chain_length: usize,
     /// A map from parameter name to typed parameter values.
     registry: HashMap<String, Arc<dyn Any+Send+Sync>>,
 }

+impl Default for Config {
+    fn default() -> Self {
+        Config {
+            progress_mode: ProgressMode::default(),
+            fuse_chain_length: 2,
+            registry: HashMap::new(),
+        }
+    }
+}
+
 impl Config {
     /// Installs options into a [getopts::Options] struct that correspond
     /// to the parameters in the configuration.
@@ -99,6 +114,7 @@ impl Config {
     #[cfg(feature = "getopts")]
     pub fn install_options(opts: &mut getopts::Options) {
         opts.optopt("", "progress-mode", "progress tracking mode (eager or demand)", "MODE");
+        opts.optopt("", "fuse-chain-length", "minimum chain length for pipeline fusion (0 disables, default: 2)", "N");
     }

     /// Instantiates a configuration based upon the parsed options in `matches`.
@@ -113,7 +129,10 @@ impl Config {
     pub fn from_matches(matches: &getopts::Matches) -> Result<Config, String> {
         let progress_mode = matches
             .opt_get_default("progress-mode", ProgressMode::Demand)?;
-        Ok(Config::default().progress_mode(progress_mode))
+        let fuse_chain_length: usize = matches
+            .opt_get_default("fuse-chain-length", 2)
+            .map_err(|e: std::num::ParseIntError| e.to_string())?;
+        Ok(Config::default().progress_mode(progress_mode).fuse_chain_length(fuse_chain_length))
     }

     /// Sets the progress mode to `progress_mode`.
@@ -122,6 +141,15 @@ impl Config {
         self
     }

+    /// Sets the minimum chain length for pipeline chain fusion (default: 2).
+    ///
+    /// Chains of pipeline-connected operators shorter than this threshold
+    /// will not be fused. Set to 0 to disable fusion entirely.
+    pub fn fuse_chain_length(mut self, fuse_chain_length: usize) -> Self {
+        self.fuse_chain_length = fuse_chain_length;
+        self
+    }
+
     /// Sets a typed configuration parameter for the given `key`.
     ///
     /// It is recommended to install a single configuration struct using a key
@@ -668,7 +696,17 @@ impl<A: Allocate> Worker<A> {
             func(&mut resources, &mut builder)
         };

-        let operator = subscope.into_inner().build(self);
+        let mut subscope = subscope.into_inner();
+
+        // Register the fusion pass if enabled.
+        let fuse_chain_length = self.config().fuse_chain_length;
+        if fuse_chain_length >= 2 {
+            subscope.add_graph_pass(Box::new(
+                crate::progress::fusion::FusionPass::new(fuse_chain_length)
+            ));
+        }
+
+        let operator = subscope.build(self);

         if let Some(l) = logging.as_mut() {
             l.log(crate::logging::OperatesEvent {
diff --git a/timely/tests/chain_fusion.rs b/timely/tests/chain_fusion.rs
new file mode 100644
index 000000000..a3ca309af
--- /dev/null
+++ b/timely/tests/chain_fusion.rs
@@ -0,0 +1,428 @@
+//! Tests for pipeline group fusion.
+
+use std::sync::{Arc, Mutex};
+use timely::dataflow::operators::{ToStream, Concat, Inspect, Probe, Feedback, ConnectLoop};
+use timely::dataflow::operators::vec::{Map, Filter, Partition};
+
+/// Verifies that a chain of map operators produces correct output.
+#[test]
+fn chain_fusion_correctness() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            (0..10u64)
+                .to_stream(scope)
+                .map(|x| x + 1)
+                .map(|x| x * 2)
+                .map(|x| x + 10)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let expected: Vec<u64> = (0..10).map(|x| (x + 1) * 2 + 10).collect();
+    assert_eq!(got, expected);
+}
+
+/// Verifies that a longer chain of maps (5 operators) produces correct output.
+#[test]
+fn chain_fusion_long_chain() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            (0..5u64)
+                .to_stream(scope)
+                .map(|x| x + 1)
+                .map(|x| x * 2)
+                .map(|x| x + 3)
+                .map(|x| x * 4)
+                .map(|x| x + 5)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let expected: Vec<u64> = (0..5).map(|x| ((x + 1) * 2 + 3) * 4 + 5).collect();
+    assert_eq!(got, expected);
+}
+
+/// Verifies that fusion works with probe (which tests that the dataflow completes).
+#[test]
+fn chain_fusion_with_probe() {
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let probe = worker.dataflow::<usize, _, _>(|scope| {
+            (0..100u64)
+                .to_stream(scope)
+                .map(|x| x + 1)
+                .map(|x| x * 2)
+                .map(|x| x + 10)
+                .probe()
+                .0
+        });
+
+        worker.step_while(|| probe.less_than(&usize::MAX));
+    }).unwrap();
+}
+
+/// Verifies that fusion is disabled when fuse_chain_length is 0.
+#[test]
+fn chain_fusion_disabled() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    let config = timely::Config {
+        communication: timely::CommunicationConfig::Thread,
+        worker: timely::WorkerConfig::default().fuse_chain_length(0),
+    };
+
+    timely::execute(config, move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            (0..10u64)
+                .to_stream(scope)
+                .map(|x| x + 1)
+                .map(|x| x * 2)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let expected: Vec<u64> = (0..10).map(|x| (x + 1) * 2).collect();
+    assert_eq!(got, expected);
+}
+
+/// Verifies that fusion works with notify=true operators (inspect uses unary_frontier).
+/// This test drives multiple rounds to exercise frontier propagation within the fused chain.
+#[test]
+fn chain_fusion_notify_operator() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        let (mut input, probe) = worker.dataflow::<usize, _, _>(|scope| {
+            use timely::dataflow::operators::Input;
+            let (input, stream) = scope.new_input();
+            let probe = stream
+                .map(|x: u64| x + 1)
+                .map(|x| x * 2)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                })
+                .probe()
+                .0;
+            (input, probe)
+        });
+
+        for round in 0..5usize {
+            input.send(round as u64);
+            input.advance_to(round + 1);
+            worker.step_while(|| probe.less_than(&(round + 1)));
+        }
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let expected: Vec<u64> = (0..5).map(|x: u64| (x + 1) * 2).collect();
+    assert_eq!(got, expected);
+}
+
+/// Verifies that fusion works with a unary_notify operator that buffers data
+/// and emits on frontier notification. This exercises frontier propagation within
+/// the fused chain.
+#[test]
+fn chain_fusion_unary_notify() {
+    use timely::dataflow::channels::pact::Pipeline;
+    use timely::dataflow::operators::generic::operator::Operator;
+
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        let (mut input, probe) = worker.dataflow::<usize, _, _>(|scope| {
+            use timely::dataflow::operators::Input;
+            let (input, stream) = scope.new_input();
+            let probe = stream
+                .map(|x: u64| x + 1)
+                // A unary_notify operator that buffers data and emits on notification.
+                // This is a notify=true operator with 1 input and 1 output.
+                .unary_notify(Pipeline, "Buffer", vec![], {
+                    let mut stash: std::collections::HashMap<usize, Vec<u64>> = std::collections::HashMap::new();
+                    move |input, output, notificator| {
+                        input.for_each(|time, data| {
+                            stash.entry(time.time().clone())
+                                .or_insert_with(Vec::new)
+                                .extend(data.drain(..));
+                            notificator.notify_at(time.retain());
+                        });
+                        notificator.for_each(|time, _count, _notify| {
+                            if let Some(data) = stash.remove(time.time()) {
+                                let mut session = output.session(&time);
+                                for datum in data {
+                                    session.give(datum);
+                                }
+                            }
+                        });
+                    }
+                })
+                .map(|x: u64| x * 10)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                })
+                .probe()
+                .0;
+            (input, probe)
+        });
+
+        for round in 0..5usize {
+            input.send(round as u64);
+            input.advance_to(round + 1);
+            worker.step_while(|| probe.less_than(&(round + 1)));
+        }
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    // Each value: (x + 1) buffered by unary_notify, then * 10
+    let expected: Vec<u64> = (0..5).map(|x: u64| (x + 1) * 10).collect();
+    assert_eq!(got, expected);
+}
+
+/// Verifies that flat_map (single in/out, pipeline, notify=false) is also fused.
+#[test]
+fn chain_fusion_flat_map() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            (0..5u64)
+                .to_stream(scope)
+                .map(|x| x + 1)
+                .flat_map(|x| vec![x, x * 10])
+                .map(|x| x + 100)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let mut expected: Vec<u64> = (0..5u64)
+        .map(|x| x + 1)
+        .flat_map(|x| vec![x, x * 10])
+        .map(|x| x + 100)
+        .collect();
+    expected.sort();
+    assert_eq!(got, expected);
+}
+
+/// Diamond pattern: stream -> map (left) + map (right) -> concat -> inspect.
+/// All operators are fusible (!notify, pipeline, identity summary).
+#[test]
+fn group_fusion_diamond() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            let stream = (0..10u64).to_stream(scope);
+            let left = stream.clone().map(|x| x + 1);
+            let right = stream.map(|x| x + 100);
+            left.concat(right)
+                .map(|x| x * 2)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let mut expected: Vec<u64> = (0..10u64)
+        .flat_map(|x| vec![(x + 1) * 2, (x + 100) * 2])
+        .collect();
+    expected.sort();
+    assert_eq!(got, expected);
+}
+
+/// Diamond with probe: verifies dataflow completion with DAG fusion.
+#[test]
+fn group_fusion_diamond_with_probe() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        let (mut input, probe) = worker.dataflow::<usize, _, _>(|scope| {
+            use timely::dataflow::operators::Input;
+            let (input, stream) = scope.new_input();
+            let left = stream.clone().map(|x: u64| x + 1);
+            let right = stream.map(|x: u64| x + 100);
+            let probe = left.concat(right)
+                .map(|x| x * 2)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                })
+                .probe()
+                .0;
+            (input, probe)
+        });
+
+        for round in 0..5usize {
+            input.send(round as u64);
+            input.advance_to(round + 1);
+            worker.step_while(|| probe.less_than(&(round + 1)));
+        }
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let mut expected: Vec<u64> = (0..5u64)
+        .flat_map(|x| vec![(x + 1) * 2, (x + 100) * 2])
+        .collect();
+    expected.sort();
+    assert_eq!(got, expected);
+}
+
+/// Multi-input merge: two independent input streams -> concat -> map.
+#[test]
+fn group_fusion_multi_input_merge() {
+    let result = Arc::new(Mutex::new(Vec::new()));
+    let result2 = Arc::clone(&result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let result3 = Arc::clone(&result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            let s1 = (0..5u64).to_stream(scope).map(|x| x + 1);
+            let s2 = (10..15u64).to_stream(scope).map(|x| x + 1);
+            s1.concat(s2)
+                .map(|x| x * 3)
+                .inspect(move |x| {
+                    result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got = result.lock().unwrap().clone();
+    got.sort();
+    let mut expected: Vec<u64> = (0..5u64).map(|x| (x + 1) * 3)
+        .chain((10..15u64).map(|x| (x + 1) * 3))
+        .collect();
+    expected.sort();
+    assert_eq!(got, expected);
+}
+
+/// Branch without merge: map -> (map + map) with two separate outputs consumed by inspect.
+/// The two branches are not merged back, testing fan-out group outputs.
+#[test]
+fn group_fusion_branch() {
+    let left_result = Arc::new(Mutex::new(Vec::new()));
+    let right_result = Arc::new(Mutex::new(Vec::new()));
+    let left_result2 = Arc::clone(&left_result);
+    let right_result2 = Arc::clone(&right_result);
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let left_result3 = Arc::clone(&left_result2);
+        let right_result3 = Arc::clone(&right_result2);
+        worker.dataflow::<usize, _, _>(|scope| {
+            let stream = (0..5u64).to_stream(scope).map(|x| x + 1);
+            // Two branches from the same map output.
+            stream.clone().map(|x| x * 2)
+                .inspect(move |x| {
+                    left_result3.lock().unwrap().push(*x);
+                });
+            stream.map(|x| x * 3)
+                .inspect(move |x| {
+                    right_result3.lock().unwrap().push(*x);
+                });
+        });
+    }).unwrap();
+
+    let mut got_left = left_result.lock().unwrap().clone();
+    got_left.sort();
+    let expected_left: Vec<u64> = (0..5).map(|x| (x + 1) * 2).collect();
+    assert_eq!(got_left, expected_left);
+
+    let mut got_right = right_result.lock().unwrap().clone();
+    got_right.sort();
+    let expected_right: Vec<u64> = (0..5).map(|x| (x + 1) * 3).collect();
+    assert_eq!(got_right, expected_right);
+}
+
+/// Collatz mutual recursion with feedback loops.
+/// This exercises DAG fusion with external feedback edges.
+#[test]
+fn group_fusion_collatz_mutual_recursion() {
+    let config = timely::Config {
+        communication: timely::CommunicationConfig::Thread,
+        worker: timely::WorkerConfig::default(),
+    };
+
+    timely::execute(config, |worker| {
+        worker.dataflow::<u64, _, _>(|scope| {
+            let (handle0, stream0) = scope.feedback(1);
+            let (handle1, stream1) = scope.feedback(1);
+
+            let results0 = stream0.map(|x: u64| x / 2).filter(|x| *x != 1);
+            let results1 = stream1.map(|x: u64| 3 * x + 1);
+
+            let mut parts =
+                (1u64..10)
+                    .to_stream(scope)
+                    .concat(results0)
+                    .concat(results1)
+                    .inspect(|_x| {})
+                    .partition(2, |x| (x % 2, x));
+
+            parts.pop().unwrap().connect_loop(handle1);
+            parts.pop().unwrap().connect_loop(handle0);
+        });
+    }).unwrap();
+}
+
+/// Repeated diamond chain with probe: tests larger fused groups.
+#[test]
+fn group_fusion_repeated_diamonds_with_probe() {
+    use timely::dataflow::operators::Input;
+
+    timely::execute_from_args(std::env::args(), move |worker| {
+        let (mut input, probe) = worker.dataflow(|scope| {
+            let (input, mut stream) = scope.new_input();
+            for _diamond in 0..15 {
+                let left = stream.clone().map(|x: u64| x);
+                let right = stream.map(|x: u64| x);
+                stream = left.concat(right).container::<Vec<u64>>();
+            }
+            let (probe, _stream) = stream.probe();
+            (input, probe)
+        });
+
+        for round in 0..5usize {
+            input.send(0u64);
+            input.advance_to(round);
+            while probe.less_than(&round) {
+                worker.step();
+            }
+        }
+    }).unwrap();
+}
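The fusion code above leans on two classic graph steps: operators connected by fusible edges are grouped with union-find, and each group's members are then ordered by Kahn's algorithm over the internal edges (the diff's `topological_sort(group, edge_stash)` call). A standalone sketch of both steps, detached from the crate's actual types — the edge representation and helper signatures here are illustrative, not the crate's API:

```rust
use std::collections::{HashMap, HashSet};

/// Union-find `find` with path compression.
fn find(parent: &mut Vec<usize>, x: usize) -> usize {
    if parent[x] != x {
        let next = parent[x];
        let root = find(parent, next);
        parent[x] = root;
    }
    parent[x]
}

/// Merge the sets containing `a` and `b`, keeping the smaller root.
fn union(parent: &mut Vec<usize>, a: usize, b: usize) {
    let (ra, rb) = (find(parent, a), find(parent, b));
    if ra != rb {
        parent[ra.max(rb)] = ra.min(rb);
    }
}

/// Kahn's algorithm restricted to edges whose endpoints are both in `group`.
fn topological_sort(group: &[usize], edges: &[(usize, usize)]) -> Vec<usize> {
    let members: HashSet<usize> = group.iter().copied().collect();
    let mut indegree: HashMap<usize, usize> = group.iter().map(|&n| (n, 0)).collect();
    let mut succs: HashMap<usize, Vec<usize>> = HashMap::new();
    for &(src, tgt) in edges {
        if members.contains(&src) && members.contains(&tgt) {
            *indegree.get_mut(&tgt).unwrap() += 1;
            succs.entry(src).or_default().push(tgt);
        }
    }
    // Repeatedly emit the lowest-indexed ready node for a deterministic order.
    let mut ready: Vec<usize> = group.iter().copied().filter(|n| indegree[n] == 0).collect();
    let mut order = Vec::with_capacity(group.len());
    while !ready.is_empty() {
        ready.sort_unstable_by(|a, b| b.cmp(a));
        let node = ready.pop().unwrap();
        order.push(node);
        for &succ in succs.get(&node).into_iter().flatten() {
            let d = indegree.get_mut(&succ).unwrap();
            *d -= 1;
            if *d == 0 {
                ready.push(succ);
            }
        }
    }
    order
}

fn main() {
    // Diamond: 1 -> {2, 3} -> 4; all four edges assumed fusible.
    let edges = [(1, 2), (1, 3), (2, 4), (3, 4)];
    let mut parent: Vec<usize> = (0..5).collect();
    for &(a, b) in &edges {
        union(&mut parent, a, b);
    }
    // Nodes 1..5 all share one root, so they form a single fusion group.
    let group: Vec<usize> = (1..5).filter(|&n| find(&mut parent, n) == find(&mut parent, 1)).collect();
    assert_eq!(group, vec![1, 2, 3, 4]);
    // Members in dependency order: the source first, the merge point last.
    assert_eq!(topological_sort(&group, &edges), vec![1, 2, 3, 4]);
}
```

The scheduling invariant in the head ("data pushed by a producer through a pipeline channel is available to its consumer when the consumer runs") follows directly from this order: every internal edge points from an earlier to a later position in `order`.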