diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index cc68176c93..2729d73e7e 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1164,5 +1164,158 @@ def decorator(func: FuncT) -> FuncT: assert body["agents"][0]["mcp_tool_enabled"] is True +class TestAgentFunctionAppErrorPaths: + """Test suite for error handling paths.""" + + def test_init_with_invalid_max_poll_retries(self) -> None: + """Test initialization handles invalid max_poll_retries by falling back to default.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + # Test with invalid type + app = AgentFunctionApp(agents=[mock_agent], max_poll_retries="invalid") + assert app.max_poll_retries >= 1 # Should use default + + # Test with None + app2 = AgentFunctionApp(agents=[mock_agent], max_poll_retries=None) + assert app2.max_poll_retries >= 1 # Should use default + + def test_init_with_invalid_poll_interval_seconds(self) -> None: + """Test initialization handles invalid poll_interval_seconds by falling back to default.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + # Test with invalid type + app = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds="invalid") + assert app.poll_interval_seconds > 0 # Should use default + + # Test with None + app2 = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds=None) + assert app2.poll_interval_seconds > 0 # Should use default + + def test_get_agent_raises_for_unregistered_agent(self) -> None: + """Test get_agent raises ValueError for unregistered agent.""" + mock_agent = Mock() + mock_agent.name = "RegisteredAgent" + + app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False) + + # Create mock orchestration context + mock_context = Mock() + + # Should raise ValueError for unregistered agent + with pytest.raises(ValueError, match="Agent 'UnknownAgent' is not registered"): + app.get_agent(mock_context, "UnknownAgent") + + def test_convert_payload_to_text_with_response_key(self) -> None: + """Test _convert_payload_to_text returns response key value.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # Test with response key + payload = {"response": "Test response"} + result = app._convert_payload_to_text(payload) + assert result == "Test response" + + # Test with error key + payload = {"error": "Error message"} + result = app._convert_payload_to_text(payload) + assert result == "Error message" + + # Test with message key + payload = {"message": "Message text"} + result = app._convert_payload_to_text(payload) + assert result == "Message text" + + # Test with no matching keys - should return JSON string + payload = {"other": "value"} + result = app._convert_payload_to_text(payload) + assert "other" in result + assert "value" in result + + def test_create_session_id_with_thread_id(self) -> None: + """Test _create_session_id with provided thread_id.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # With thread_id provided + session_id = app._create_session_id("TestAgent", "my-thread-123") + assert session_id.key == "my-thread-123" + + # Without thread_id (None) - should generate random + session_id = app._create_session_id("TestAgent", None) + assert session_id.key is not None + assert len(session_id.key) > 0 + + def test_resolve_thread_id_from_body(self) -> None: + """Test _resolve_thread_id extracts from body.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + mock_req = Mock() + mock_req.params = {} + + # Thread ID in body - field name is "thread_id" + req_body = {"thread_id": "body-thread-123"} + result = app._resolve_thread_id(mock_req, req_body) + assert result == "body-thread-123" + + def test_select_body_parser_json_content_type(self) -> None: + """Test _select_body_parser for JSON content type.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # Test with application/json + parser, format_str = app._select_body_parser("application/json") + assert parser == app._parse_json_body + assert format_str == "json" + + # Test with +json suffix + parser, format_str = app._select_body_parser("application/vnd.api+json") + assert parser == app._parse_json_body + assert format_str == "json" + + def test_accepts_json_response_with_accept_header(self) -> None: + """Test _accepts_json_response checks accept header.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # With application/json in accept header + headers = {"accept": "application/json"} + result = app._accepts_json_response(headers) + assert result is True + + # Without accept header + headers = {} + result = app._accepts_json_response(headers) + assert result is False + + def test_parse_json_body_invalid_type(self) -> None: + """Test _parse_json_body raises error for invalid JSON.""" + from agent_framework_azurefunctions._errors import IncomingRequestError + + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # Mock request with non-dict JSON + mock_req = Mock() + mock_req.get_json.return_value = ["not", "a", "dict"] + + with pytest.raises(IncomingRequestError, match="Invalid JSON payload"): + app._parse_json_body(mock_req) + + def test_coerce_to_bool_with_none(self) -> None: + """Test _coerce_to_bool handles None and various value types.""" + app = AgentFunctionApp(enable_http_endpoints=False, enable_health_check=False) + + # None returns False + assert app._coerce_to_bool(None) is False + + # Integer + assert app._coerce_to_bool(1) is True + assert app._coerce_to_bool(0) is False + + # String + assert app._coerce_to_bool("true") is True + assert app._coerce_to_bool("false") is False + + # Other type returns False + assert app._coerce_to_bool([]) is False + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 2fdbc3463e..a9db60c9dc 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -198,6 +198,114 @@ def test_entity_function_restores_existing_state(self) -> None: persisted_state = mock_context.set_state.call_args[0][0] assert persisted_state["data"]["conversationHistory"] == [] + def test_entity_function_handles_string_input(self) -> None: + """Test that the entity function handles non-dict input by converting to string.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("String response")) + + entity_function = create_agent_entity(mock_agent) + + # Mock context with non-dict input (like a number) + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.entity_key = "conv-456" + # Use a number to test the str() conversion path + mock_context.get_input.return_value = 12345 + mock_context.get_state.return_value = None + + # Execute - entity will convert non-dict input to string + entity_function(mock_context) + + # Verify the result was set + assert mock_context.set_result.called + + def test_entity_function_handles_none_input(self) -> None: + """Test that the entity function handles None input by converting to empty string.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Empty response")) + + entity_function = create_agent_entity(mock_agent) + + # Mock context with None input + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.entity_key = "conv-789" + mock_context.get_input.return_value = None + mock_context.get_state.return_value = None + + # Execute - should hit error path since entity expects dict or valid JSON string + entity_function(mock_context) + + # Verify the result was set (likely error result) + assert mock_context.set_result.called + + def test_entity_function_handles_event_loop_runtime_error(self) -> None: + """Test that the entity function handles RuntimeError from get_event_loop by creating a new loop.""" + from unittest.mock import patch + + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.entity_key = "conv-loop-test" + mock_context.get_input.return_value = {"message": "Test"} + mock_context.get_state.return_value = None + + # Simulate RuntimeError when getting event loop + with ( + patch("asyncio.get_event_loop", side_effect=RuntimeError("No event loop")), + patch("asyncio.new_event_loop") as mock_new_loop, + patch("asyncio.set_event_loop") as mock_set_loop, + ): + mock_loop = Mock() + mock_loop.is_running.return_value = False + mock_loop.run_until_complete = Mock() + mock_new_loop.return_value = mock_loop + + # Execute + entity_function(mock_context) + + # Verify new event loop was created + mock_new_loop.assert_called_once() + mock_set_loop.assert_called_once_with(mock_loop) + + def test_entity_function_handles_running_event_loop(self) -> None: + """Test that the entity function handles a running event loop by creating a temporary loop.""" + from unittest.mock import patch + + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "run" + mock_context.entity_key = "conv-running-loop" + mock_context.get_input.return_value = {"message": "Test"} + mock_context.get_state.return_value = None + + # Simulate a running event loop + mock_existing_loop = Mock() + mock_existing_loop.is_running.return_value = True + + mock_temp_loop = Mock() + mock_temp_loop.run_until_complete = Mock() + mock_temp_loop.close = Mock() + + with ( + patch("asyncio.get_event_loop", return_value=mock_existing_loop), + patch("asyncio.new_event_loop", return_value=mock_temp_loop), + ): + # Execute + entity_function(mock_context) + + # Verify temporary loop was created and closed + mock_temp_loop.run_until_complete.assert_called_once() + mock_temp_loop.close.assert_called_once() + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_errors.py b/python/packages/azurefunctions/tests/test_errors.py new file mode 100644 index 0000000000..09bf8797c6 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_errors.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for custom exception types.""" + +import pytest + +from agent_framework_azurefunctions._errors import IncomingRequestError + + +class TestIncomingRequestError: + """Test suite for IncomingRequestError exception.""" + + def test_incoming_request_error_default_status_code(self) -> None: + """Test that IncomingRequestError has a default status code of 400.""" + error = IncomingRequestError("Invalid request") + + assert str(error) == "Invalid request" + assert error.status_code == 400 + + def test_incoming_request_error_custom_status_code(self) -> None: + """Test that IncomingRequestError can have a custom status code.""" + error = IncomingRequestError("Unauthorized", status_code=401) + + assert str(error) == "Unauthorized" + assert error.status_code == 401 + + def test_incoming_request_error_is_value_error(self) -> None: + """Test that IncomingRequestError inherits from ValueError.""" + error = IncomingRequestError("Test error") + + assert isinstance(error, ValueError) + + def test_incoming_request_error_can_be_raised_and_caught(self) -> None: + """Test that IncomingRequestError can be raised and caught.""" + with pytest.raises(IncomingRequestError) as exc_info: + raise IncomingRequestError("Bad request", status_code=400) + + assert exc_info.value.status_code == 400 diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index d1be7d9a77..45e57f0dcf 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -129,6 +129,25 @@ def executor_with_context(mock_context_with_uuid: tuple[Mock, str]) -> tuple[Any class TestAgentResponseHelpers: """Tests for response handling through public AgentTask API.""" + def test_try_set_value_exception_handling(self) -> None: + """Test try_set_value handles exceptions raised when converting a successful task result to AgentResponse.""" + entity_task = _create_entity_task() + task = AgentTask(entity_task, None, "correlation-id") + + # Simulate successful entity task with invalid result that causes exception + entity_task.state = TaskState.SUCCEEDED + entity_task.result = {"invalid": "format"} # Missing required fields for AgentResponse + + # Clear pending_tasks to simulate that parent has processed the child + task.pending_tasks.clear() + + # Call try_set_value - should catch exception and set error + task.try_set_value(entity_task) + + # Verify task failed due to conversion exception + assert task.state == TaskState.FAILED + assert isinstance(task.result, Exception) + def test_try_set_value_success(self) -> None: """Test try_set_value correctly processes successful task completion.""" entity_task = _create_entity_task() @@ -279,6 +298,27 @@ def test_blocking_mode_still_works(self, executor_with_uuid: tuple[Any, Mock, st assert isinstance(result, AgentTask) +class TestAzureFunctionsAgentExecutor: + """Tests for AzureFunctionsAgentExecutor.""" + + def test_generate_unique_id(self, mock_context_with_uuid: tuple[Mock, str]) -> None: + """Test generate_unique_id method returns UUID from orchestration context.""" + from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor + + context, _ = mock_context_with_uuid + executor = AzureFunctionsAgentExecutor(context) + + # Call generate_unique_id + unique_id = executor.generate_unique_id() + + # Verify it returns the UUID from context (as string with dashes) + # The UUID is returned in standard format with dashes + context.new_uuid.assert_called_once() + # Just verify it's a string representation of UUID + assert isinstance(unique_id, str) + assert len(unique_id) > 0 + + class TestOrchestrationIntegration: """Integration tests for orchestration scenarios."""