diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 96c105493..45a8512eb 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -2087,6 +2087,7 @@ def start_activity( activity_id=activity_id, versioning_intent=versioning_intent, summary=summary, + priority=priority, ) diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 1392cd350..9dddcefaf 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -7,6 +7,7 @@ from temporalio import activity, workflow from temporalio.client import Client, WorkflowUpdateFailedError +from temporalio.common import Priority from temporalio.exceptions import ApplicationError from temporalio.testing import WorkflowEnvironment from temporalio.worker import ( @@ -324,3 +325,66 @@ async def test_workflow_instance_access_from_interceptor(client: Client): task_queue=task_queue, ) assert difference == 0 + + +@activity.defn +async def priority_activity() -> str: + return "done" + + +captured_start_activity_inputs: List[StartActivityInput] = [] + + +class PriorityCapturingInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> Optional[Type[WorkflowInboundInterceptor]]: + return PriorityCapturingInboundInterceptor + + +class PriorityCapturingInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + super().init(PriorityCapturingOutboundInterceptor(outbound)) + + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + return await super().execute_workflow(input) + + +class PriorityCapturingOutboundInterceptor(WorkflowOutboundInterceptor): + def start_activity(self, input: StartActivityInput) -> workflow.ActivityHandle: + captured_start_activity_inputs.append(input) + return super().start_activity(input) + + +@workflow.defn +class StartActivityPriorityWorkflow: + @workflow.run + async def run(self) -> str: + # Use start_activity (not execute_activity) to test that path + handle = workflow.start_activity( + priority_activity, + start_to_close_timeout=timedelta(seconds=5), + priority=Priority(priority_key=3), + ) + return await handle + + +async def test_start_activity_forwards_priority(client: Client): + captured_start_activity_inputs.clear() + task_queue = f"task_queue_{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[StartActivityPriorityWorkflow], + activities=[priority_activity], + interceptors=[PriorityCapturingInterceptor()], + ): + result = await client.execute_workflow( + StartActivityPriorityWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=task_queue, + ) + assert result == "done" + + assert len(captured_start_activity_inputs) == 1 + assert captured_start_activity_inputs[0].priority == Priority(priority_key=3)