diff --git a/src/Grpc/orchestrator_service.proto b/src/Grpc/orchestrator_service.proto index 0c34d986..cbcd648d 100644 --- a/src/Grpc/orchestrator_service.proto +++ b/src/Grpc/orchestrator_service.proto @@ -182,6 +182,7 @@ message EntityOperationSignaledEvent { google.protobuf.Timestamp scheduledTime = 3; google.protobuf.StringValue input = 4; google.protobuf.StringValue targetInstanceId = 5; // used only within histories, null in messages + TraceContext parentTraceContext = 6; } message EntityOperationCalledEvent { @@ -192,6 +193,7 @@ message EntityOperationCalledEvent { google.protobuf.StringValue parentInstanceId = 5; // used only within messages, null in histories google.protobuf.StringValue parentExecutionId = 6; // used only within messages, null in histories google.protobuf.StringValue targetInstanceId = 7; // used only within histories, null in messages + TraceContext parentTraceContext = 8; } message EntityLockRequestedEvent { @@ -318,6 +320,8 @@ message SendEntityMessageAction { EntityLockRequestedEvent entityLockRequested = 3; EntityUnlockSentEvent entityUnlockSent = 4; } + + TraceContext parentTraceContext = 5; } message OrchestratorAction { diff --git a/src/Shared/Grpc/ProtoUtils.cs b/src/Shared/Grpc/ProtoUtils.cs index 2412fac3..6bac27c5 100644 --- a/src/Shared/Grpc/ProtoUtils.cs +++ b/src/Shared/Grpc/ProtoUtils.cs @@ -415,6 +415,7 @@ internal static P.OrchestratorResponse ConstructOrchestratorResponse( out string requestId); entityConversionState.EntityRequestIds.Add(requestId); + sendAction.ParentTraceContext = CreateTraceContext(); switch (sendAction.EntityMessageTypeCase) { @@ -636,6 +637,9 @@ internal static void ToEntityBatchRequest( Id = Guid.Parse(op.EntityOperationSignaled.RequestId), Operation = op.EntityOperationSignaled.Operation, Input = op.EntityOperationSignaled.Input, + TraceContext = op.EntityOperationSignaled.ParentTraceContext is { } signalTc + ? new DistributedTraceContext(signalTc.TraceParent, signalTc.TraceState) + : null, }); operationInfos.Add(new P.OperationInfo { @@ -650,6 +654,9 @@ internal static void ToEntityBatchRequest( Id = Guid.Parse(op.EntityOperationCalled.RequestId), Operation = op.EntityOperationCalled.Operation, Input = op.EntityOperationCalled.Input, + TraceContext = op.EntityOperationCalled.ParentTraceContext is { } calledTc + ? new DistributedTraceContext(calledTc.TraceParent, calledTc.TraceState) + : null, }); operationInfos.Add(new P.OperationInfo { diff --git a/test/Worker/Grpc.Tests/ProtoUtilsTraceContextTests.cs b/test/Worker/Grpc.Tests/ProtoUtilsTraceContextTests.cs new file mode 100644 index 00000000..d130c54a --- /dev/null +++ b/test/Worker/Grpc.Tests/ProtoUtilsTraceContextTests.cs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.Entities.OperationFormat; +using Newtonsoft.Json; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Worker.Grpc.Tests; + +public class ProtoUtilsTraceContextTests +{ + static readonly ActivitySource TestSource = new(nameof(ProtoUtilsTraceContextTests)); + + [Fact] + public void SendEntityMessage_SignalEntity_SetsParentTraceContext() + { + // Arrange + using ActivityListener listener = CreateListener(); + using Activity? orchestrationActivity = TestSource.StartActivity("TestOrchestration"); + orchestrationActivity.Should().NotBeNull(); + + string requestId = Guid.NewGuid().ToString(); + string entityInstanceId = "@counter@myKey"; + string eventData = JsonConvert.SerializeObject(new + { + op = "increment", + signal = true, + id = requestId, + }); + + SendEventOrchestratorAction sendEventAction = new() + { + Id = 1, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData, + }; + + ProtoUtils.EntityConversionState entityConversionState = new(insertMissingEntityUnlocks: false); + + // Act + P.OrchestratorResponse response = ProtoUtils.ConstructOrchestratorResponse( + instanceId: "test-orchestration", + executionId: "exec-1", + customStatus: null, + actions: [sendEventAction], + completionToken: "token", + entityConversionState: entityConversionState, + orchestrationActivity: orchestrationActivity); + + // Assert + response.Actions.Should().ContainSingle(); + P.OrchestratorAction action = response.Actions[0]; + action.SendEntityMessage.Should().NotBeNull(); + action.SendEntityMessage.EntityOperationSignaled.Should().NotBeNull(); + action.SendEntityMessage.ParentTraceContext.Should().NotBeNull(); + action.SendEntityMessage.ParentTraceContext.TraceParent.Should().NotBeNullOrEmpty(); + action.SendEntityMessage.ParentTraceContext.TraceParent.Should().Contain( + orchestrationActivity!.TraceId.ToString()); + } + + [Fact] + public void SendEntityMessage_CallEntity_SetsParentTraceContext() + { + // Arrange + using ActivityListener listener = CreateListener(); + using Activity? orchestrationActivity = TestSource.StartActivity("TestOrchestration"); + orchestrationActivity.Should().NotBeNull(); + + string requestId = Guid.NewGuid().ToString(); + string entityInstanceId = "@counter@myKey"; + string eventData = JsonConvert.SerializeObject(new + { + op = "get", + signal = false, + id = requestId, + parent = "parent-instance", + }); + + SendEventOrchestratorAction sendEventAction = new() + { + Id = 1, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData, + }; + + ProtoUtils.EntityConversionState entityConversionState = new(insertMissingEntityUnlocks: false); + + // Act + P.OrchestratorResponse response = ProtoUtils.ConstructOrchestratorResponse( + instanceId: "test-orchestration", + executionId: "exec-1", + customStatus: null, + actions: [sendEventAction], + completionToken: "token", + entityConversionState: entityConversionState, + orchestrationActivity: orchestrationActivity); + + // Assert + response.Actions.Should().ContainSingle(); + P.OrchestratorAction action = response.Actions[0]; + action.SendEntityMessage.Should().NotBeNull(); + action.SendEntityMessage.EntityOperationCalled.Should().NotBeNull(); + action.SendEntityMessage.ParentTraceContext.Should().NotBeNull(); + action.SendEntityMessage.ParentTraceContext.TraceParent.Should().NotBeNullOrEmpty(); + action.SendEntityMessage.ParentTraceContext.TraceParent.Should().Contain( + orchestrationActivity!.TraceId.ToString()); + } + + [Fact] + public void SendEntityMessage_NoOrchestrationActivity_DoesNotSetParentTraceContext() + { + // Arrange + string requestId = Guid.NewGuid().ToString(); + string entityInstanceId = "@counter@myKey"; + string eventData = JsonConvert.SerializeObject(new + { + op = "increment", + signal = true, + id = requestId, + }); + + SendEventOrchestratorAction sendEventAction = new() + { + Id = 1, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData, + }; + + ProtoUtils.EntityConversionState entityConversionState = new(insertMissingEntityUnlocks: false); + + // Act + P.OrchestratorResponse response = ProtoUtils.ConstructOrchestratorResponse( + instanceId: "test-orchestration", + executionId: "exec-1", + customStatus: null, + actions: [sendEventAction], + completionToken: "token", + entityConversionState: entityConversionState, + orchestrationActivity: null); + + // Assert + response.Actions.Should().ContainSingle(); + P.OrchestratorAction action = response.Actions[0]; + action.SendEntityMessage.Should().NotBeNull(); + action.SendEntityMessage.ParentTraceContext.Should().BeNull(); + } + + [Fact] + public void SendEntityMessage_NoEntityConversionState_SendsAsSendEvent() + { + // Arrange + using ActivityListener listener = CreateListener(); + using Activity? orchestrationActivity = TestSource.StartActivity("TestOrchestration"); + + string requestId = Guid.NewGuid().ToString(); + string entityInstanceId = "@counter@myKey"; + string eventData = JsonConvert.SerializeObject(new + { + op = "increment", + signal = true, + id = requestId, + }); + + SendEventOrchestratorAction sendEventAction = new() + { + Id = 1, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData, + }; + + // Act - no entityConversionState means entity events are NOT converted + P.OrchestratorResponse response = ProtoUtils.ConstructOrchestratorResponse( + instanceId: "test-orchestration", + executionId: "exec-1", + customStatus: null, + actions: [sendEventAction], + completionToken: "token", + entityConversionState: null, + orchestrationActivity: orchestrationActivity); + + // Assert - should be a SendEvent, not SendEntityMessage + response.Actions.Should().ContainSingle(); + P.OrchestratorAction action = response.Actions[0]; + action.SendEvent.Should().NotBeNull(); + action.SendEntityMessage.Should().BeNull(); + } + + [Fact] + public void SendEntityMessage_TraceContextHasUniqueSpanId() + { + // Arrange + using ActivityListener listener = CreateListener(); + using Activity? orchestrationActivity = TestSource.StartActivity("TestOrchestration"); + orchestrationActivity.Should().NotBeNull(); + + string entityInstanceId = "@counter@myKey"; + string eventData1 = JsonConvert.SerializeObject(new + { + op = "increment", + signal = true, + id = Guid.NewGuid().ToString(), + }); + + string eventData2 = JsonConvert.SerializeObject(new + { + op = "increment", + signal = true, + id = Guid.NewGuid().ToString(), + }); + + SendEventOrchestratorAction action1 = new() + { + Id = 1, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData1, + }; + + SendEventOrchestratorAction action2 = new() + { + Id = 2, + Instance = new OrchestrationInstance { InstanceId = entityInstanceId }, + EventName = "op", + EventData = eventData2, + }; + + ProtoUtils.EntityConversionState entityConversionState = new(insertMissingEntityUnlocks: false); + + // Act + P.OrchestratorResponse response = ProtoUtils.ConstructOrchestratorResponse( + instanceId: "test-orchestration", + executionId: "exec-1", + customStatus: null, + actions: [action1, action2], + completionToken: "token", + entityConversionState: entityConversionState, + orchestrationActivity: orchestrationActivity); + + // Assert - each entity message should get a unique span ID + response.Actions.Should().HaveCount(2); + string traceParent1 = response.Actions[0].SendEntityMessage.ParentTraceContext.TraceParent; + string traceParent2 = response.Actions[1].SendEntityMessage.ParentTraceContext.TraceParent; + traceParent1.Should().NotBeNullOrEmpty(); + traceParent2.Should().NotBeNullOrEmpty(); + + // Same trace ID (from orchestration activity) + traceParent1.Should().Contain(orchestrationActivity!.TraceId.ToString()); + traceParent2.Should().Contain(orchestrationActivity.TraceId.ToString()); + + // Different span IDs + traceParent1.Should().NotBe(traceParent2); + } + + static ActivityListener CreateListener() + { + ActivityListener listener = new() + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded, + }; + + ActivitySource.AddActivityListener(listener); + return listener; + } + + [Fact] + public void ToEntityBatchRequest_SignalEntity_ExtractsTraceContext() + { + // Arrange + string traceParent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; + string traceState = "vendor=value"; + + P.EntityRequest entityRequest = new() + { + InstanceId = "@counter@myKey", + OperationRequests = + { + new P.HistoryEvent + { + EntityOperationSignaled = new P.EntityOperationSignaledEvent + { + RequestId = Guid.NewGuid().ToString(), + Operation = "increment", + ParentTraceContext = new P.TraceContext + { + TraceParent = traceParent, + TraceState = traceState, + }, + }, + }, + }, + }; + + // Act + entityRequest.ToEntityBatchRequest(out EntityBatchRequest batchRequest, out _); + + // Assert + batchRequest.Operations.Should().ContainSingle(); + batchRequest.Operations[0].TraceContext.Should().NotBeNull(); + batchRequest.Operations[0].TraceContext!.TraceParent.Should().Be(traceParent); + batchRequest.Operations[0].TraceContext!.TraceState.Should().Be(traceState); + } + + [Fact] + public void ToEntityBatchRequest_CallEntity_ExtractsTraceContext() + { + // Arrange + string traceParent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; + string traceState = "vendor=value"; + + P.EntityRequest entityRequest = new() + { + InstanceId = "@counter@myKey", + OperationRequests = + { + new P.HistoryEvent + { + EntityOperationCalled = new P.EntityOperationCalledEvent + { + RequestId = Guid.NewGuid().ToString(), + Operation = "get", + ParentInstanceId = "parent-instance", + ParentExecutionId = "parent-exec", + ParentTraceContext = new P.TraceContext + { + TraceParent = traceParent, + TraceState = traceState, + }, + }, + }, + }, + }; + + // Act + entityRequest.ToEntityBatchRequest(out EntityBatchRequest batchRequest, out _); + + // Assert + batchRequest.Operations.Should().ContainSingle(); + batchRequest.Operations[0].TraceContext.Should().NotBeNull(); + batchRequest.Operations[0].TraceContext!.TraceParent.Should().Be(traceParent); + batchRequest.Operations[0].TraceContext!.TraceState.Should().Be(traceState); + } + + [Fact] + public void ToEntityBatchRequest_NoTraceContext_LeavesTraceContextNull() + { + // Arrange + P.EntityRequest entityRequest = new() + { + InstanceId = "@counter@myKey", + OperationRequests = + { + new P.HistoryEvent + { + EntityOperationSignaled = new P.EntityOperationSignaledEvent + { + RequestId = Guid.NewGuid().ToString(), + Operation = "increment", + }, + }, + }, + }; + + // Act + entityRequest.ToEntityBatchRequest(out EntityBatchRequest batchRequest, out _); + + // Assert + batchRequest.Operations.Should().ContainSingle(); + batchRequest.Operations[0].TraceContext.Should().BeNull(); + } +}