diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
index b66939f..6f4908d 100644
--- a/RELEASE_NOTES.md
+++ b/RELEASE_NOTES.md
@@ -1,3 +1,7 @@
+### 4.10.0
+
+* Added `AsyncSeq.withCancellation` — returns a new `AsyncSeq` that passes the given `CancellationToken` to `GetAsyncEnumerator`, overriding whatever token would otherwise be supplied. Mirrors `TaskSeq.withCancellation` and is useful when consuming sequences from libraries (e.g. Entity Framework) that accept a cancellation token through `GetAsyncEnumerator`. Part of ongoing design-parity work with FSharp.Control.TaskSeq (see #277).
+
### 4.9.0
* Performance: `filterAsync` — replaced `asyncSeq`-builder implementation with a direct optimised enumerator, reducing allocation and generator overhead.
diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs
index 90fdd40..642bab5 100644
--- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs
+++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs
@@ -2522,6 +2522,13 @@ module AsyncSeq =
(emptyAsync fillChannelTask)
}
+ /// Returns a new AsyncSeq that passes the given CancellationToken to GetAsyncEnumerator,
+ /// overriding whatever token would otherwise be used. Useful when consuming sequences from
+ /// libraries (such as Entity Framework) that accept a CancellationToken through GetAsyncEnumerator.
+ let withCancellation (cancellationToken: CancellationToken) (source: AsyncSeq<'T>) : AsyncSeq<'T> =
+ { new IAsyncEnumerable<'T> with
+ member _.GetAsyncEnumerator(_ct) = source.GetAsyncEnumerator(cancellationToken) }
+
#endif
diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi
index 5113dd6..b17f019 100644
--- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi
+++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi
@@ -824,6 +824,14 @@ module AsyncSeq =
/// Transforms an async seq to a new one that fetches values ahead of time to improve throughput.
val prefetch<'T> : numberToPrefetch: int -> source: AsyncSeq<'T> -> AsyncSeq<'T>
+ ///
+ /// Returns a new AsyncSeq that passes the given CancellationToken to
+ /// GetAsyncEnumerator, overriding whatever token would otherwise be used when iterating.
+ /// This is useful when consuming sequences from libraries such as Entity Framework that
+ /// accept a CancellationToken through GetAsyncEnumerator.
+ ///
+ val withCancellation<'T> : cancellationToken: System.Threading.CancellationToken -> source: AsyncSeq<'T> -> AsyncSeq<'T>
+
#endif
diff --git a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs
index 17277e6..6e39e77 100644
--- a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs
+++ b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs
@@ -3662,3 +3662,60 @@ let ``AsyncSeq.insertAt raises ArgumentException when index exceeds length`` ()
|> AsyncSeq.toArrayAsync
|> Async.RunSynchronously |> ignore)
|> ignore
+
+// ===== withCancellation =====
+
+[]
+let ``AsyncSeq.withCancellation passes token to enumerator`` () =
+ use cts = new System.Threading.CancellationTokenSource()
+ let receivedToken = ref System.Threading.CancellationToken.None
+ let source =
+ { new System.Collections.Generic.IAsyncEnumerable with
+ member _.GetAsyncEnumerator(ct) =
+ receivedToken.Value <- ct
+ (AsyncSeq.ofSeq [1; 2; 3]).GetAsyncEnumerator(ct) }
+ source
+ |> AsyncSeq.withCancellation cts.Token
+ |> AsyncSeq.toArrayAsync
+ |> Async.RunSynchronously
+ |> ignore
+ Assert.AreEqual(cts.Token, receivedToken.Value)
+
+[]
+let ``AsyncSeq.withCancellation overrides incoming token`` () =
+ use cts1 = new System.Threading.CancellationTokenSource()
+ use cts2 = new System.Threading.CancellationTokenSource()
+ let receivedToken = ref System.Threading.CancellationToken.None
+ let source : System.Collections.Generic.IAsyncEnumerable =
+ { new System.Collections.Generic.IAsyncEnumerable with
+ member _.GetAsyncEnumerator(ct) =
+ receivedToken.Value <- ct
+ (AsyncSeq.ofSeq [1; 2; 3]).GetAsyncEnumerator(ct) }
+ let wrapped = source |> AsyncSeq.withCancellation cts1.Token
+ // Enumerate with cts2's token - withCancellation should still pass cts1's token
+ let e = wrapped.GetAsyncEnumerator(cts2.Token)
+ e.MoveNextAsync().AsTask() |> Async.AwaitTask |> Async.RunSynchronously |> ignore
+ e.DisposeAsync() |> ignore
+ Assert.AreEqual(cts1.Token, receivedToken.Value)
+
+[]
+let ``AsyncSeq.withCancellation preserves sequence values`` () =
+ use cts = new System.Threading.CancellationTokenSource()
+ let result =
+ AsyncSeq.ofSeq [1; 2; 3; 4; 5]
+ |> AsyncSeq.withCancellation cts.Token
+ |> AsyncSeq.toArrayAsync
+ |> Async.RunSynchronously
+ Assert.AreEqual([| 1; 2; 3; 4; 5 |], result)
+
+[]
+let ``AsyncSeq.withCancellation with cancelled token raises OperationCanceledException`` () =
+ use cts = new System.Threading.CancellationTokenSource()
+ cts.Cancel()
+ Assert.Catch(fun () ->
+ AsyncSeq.ofSeq [1; 2; 3]
+ |> AsyncSeq.withCancellation cts.Token
+ |> AsyncSeq.toArrayAsync
+ |> Async.RunSynchronously
+ |> ignore)
+ |> ignore
diff --git a/version.props b/version.props
index 026dfba..989fd4a 100644
--- a/version.props
+++ b/version.props
@@ -1,5 +1,5 @@
- 4.8.0
+ 4.10.0