diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index b66939f..6f4908d 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,3 +1,7 @@ +### 4.10.0 + +* Added `AsyncSeq.withCancellation` — returns a new `AsyncSeq` that passes the given `CancellationToken` to `GetAsyncEnumerator`, overriding whatever token would otherwise be supplied. Mirrors `TaskSeq.withCancellation` and is useful when consuming sequences from libraries (e.g. Entity Framework) that accept a cancellation token through `GetAsyncEnumerator`. Part of ongoing design-parity work with FSharp.Control.TaskSeq (see #277). + ### 4.9.0 * Performance: `filterAsync` — replaced `asyncSeq`-builder implementation with a direct optimised enumerator, reducing allocation and generator overhead. diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs index 90fdd40..642bab5 100644 --- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fs +++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fs @@ -2522,6 +2522,13 @@ module AsyncSeq = (emptyAsync fillChannelTask) } + /// Returns a new AsyncSeq that passes the given CancellationToken to GetAsyncEnumerator, + /// overriding whatever token would otherwise be used. Useful when consuming sequences from + /// libraries (such as Entity Framework) that accept a CancellationToken through GetAsyncEnumerator. + let withCancellation (cancellationToken: CancellationToken) (source: AsyncSeq<'T>) : AsyncSeq<'T> = + { new IAsyncEnumerable<'T> with + member _.GetAsyncEnumerator(_ct) = source.GetAsyncEnumerator(cancellationToken) } + #endif diff --git a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi index 5113dd6..b17f019 100644 --- a/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi +++ b/src/FSharp.Control.AsyncSeq/AsyncSeq.fsi @@ -824,6 +824,14 @@ module AsyncSeq = /// Transforms an async seq to a new one that fetches values ahead of time to improve throughput. val prefetch<'T> : numberToPrefetch: int -> source: AsyncSeq<'T> -> AsyncSeq<'T> + /// + /// Returns a new AsyncSeq that passes the given CancellationToken to + /// GetAsyncEnumerator, overriding whatever token would otherwise be used when iterating. + /// This is useful when consuming sequences from libraries such as Entity Framework that + /// accept a CancellationToken through GetAsyncEnumerator. + /// + val withCancellation<'T> : cancellationToken: System.Threading.CancellationToken -> source: AsyncSeq<'T> -> AsyncSeq<'T> + #endif diff --git a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs index 17277e6..6e39e77 100644 --- a/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs +++ b/tests/FSharp.Control.AsyncSeq.Tests/AsyncSeqTests.fs @@ -3662,3 +3662,60 @@ let ``AsyncSeq.insertAt raises ArgumentException when index exceeds length`` () |> AsyncSeq.toArrayAsync |> Async.RunSynchronously |> ignore) |> ignore + +// ===== withCancellation ===== + +[] +let ``AsyncSeq.withCancellation passes token to enumerator`` () = + use cts = new System.Threading.CancellationTokenSource() + let receivedToken = ref System.Threading.CancellationToken.None + let source = + { new System.Collections.Generic.IAsyncEnumerable with + member _.GetAsyncEnumerator(ct) = + receivedToken.Value <- ct + (AsyncSeq.ofSeq [1; 2; 3]).GetAsyncEnumerator(ct) } + source + |> AsyncSeq.withCancellation cts.Token + |> AsyncSeq.toArrayAsync + |> Async.RunSynchronously + |> ignore + Assert.AreEqual(cts.Token, receivedToken.Value) + +[] +let ``AsyncSeq.withCancellation overrides incoming token`` () = + use cts1 = new System.Threading.CancellationTokenSource() + use cts2 = new System.Threading.CancellationTokenSource() + let receivedToken = ref System.Threading.CancellationToken.None + let source : System.Collections.Generic.IAsyncEnumerable = + { new System.Collections.Generic.IAsyncEnumerable with + member _.GetAsyncEnumerator(ct) = + receivedToken.Value <- ct + (AsyncSeq.ofSeq [1; 2; 3]).GetAsyncEnumerator(ct) } + let wrapped = source |> AsyncSeq.withCancellation cts1.Token + // Enumerate with cts2's token - withCancellation should still pass cts1's token + let e = wrapped.GetAsyncEnumerator(cts2.Token) + e.MoveNextAsync().AsTask() |> Async.AwaitTask |> Async.RunSynchronously |> ignore + e.DisposeAsync() |> ignore + Assert.AreEqual(cts1.Token, receivedToken.Value) + +[] +let ``AsyncSeq.withCancellation preserves sequence values`` () = + use cts = new System.Threading.CancellationTokenSource() + let result = + AsyncSeq.ofSeq [1; 2; 3; 4; 5] + |> AsyncSeq.withCancellation cts.Token + |> AsyncSeq.toArrayAsync + |> Async.RunSynchronously + Assert.AreEqual([| 1; 2; 3; 4; 5 |], result) + +[] +let ``AsyncSeq.withCancellation with cancelled token raises OperationCanceledException`` () = + use cts = new System.Threading.CancellationTokenSource() + cts.Cancel() + Assert.Catch(fun () -> + AsyncSeq.ofSeq [1; 2; 3] + |> AsyncSeq.withCancellation cts.Token + |> AsyncSeq.toArrayAsync + |> Async.RunSynchronously + |> ignore) + |> ignore diff --git a/version.props b/version.props index 026dfba..989fd4a 100644 --- a/version.props +++ b/version.props @@ -1,5 +1,5 @@ - 4.8.0 + 4.10.0