diff --git a/azure-quantum/tests/unit/local/mock_client.py b/azure-quantum/tests/unit/local/mock_client.py index ba6ad1fb..50b55fec 100644 --- a/azure-quantum/tests/unit/local/mock_client.py +++ b/azure-quantum/tests/unit/local/mock_client.py @@ -16,6 +16,14 @@ from types import SimpleNamespace from azure.quantum._client import ServicesClient from azure.quantum._client.models import JobDetails, SessionDetails, ItemDetails +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + LOCATION, + ENDPOINT_URI, + WORKSPACE, +) def _paged(items: List, page_size: int = 100) -> ItemPaged: @@ -349,6 +357,35 @@ def list( return _paged(items[skip : skip + top], page_size=top) +class MockWorkspaceMgmtClient: + """Mock management client that avoids network calls to ARM/ARG.""" + + def __init__(self, credential: Optional[object] = None, base_url: Optional[str] = None, user_agent: Optional[str] = None) -> None: + self._credential = credential + self._base_url = base_url + self._user_agent = user_agent + + def close(self) -> None: + """No-op close for mock.""" + pass + + def __enter__(self) -> 'MockWorkspaceMgmtClient': + return self + + def __exit__(self, *exc_details) -> None: + pass + + def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams) -> None: + connection_params.subscription_id = SUBSCRIPTION_ID + connection_params.resource_group = RESOURCE_GROUP + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams) -> None: + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + class MockServicesClient(ServicesClient): def __init__(self, authentication_policy: Optional[object] = None) -> None: # in-memory stores @@ -363,8 +400,20 @@ def __init__(self, authentication_policy: Optional[object] = None) -> None: # Mimic ServicesClient config shape for tests that inspect policy self._config = SimpleNamespace(authentication_policy=authentication_policy) + def __enter__(self) -> 'MockServicesClient': + return self + + def __exit__(self, *exc_details) -> None: + pass + class WorkspaceMock(Workspace): + def __init__(self, **kwargs) -> None: + # Create and pass mock management client to prevent network calls + if '_mgmt_client' not in kwargs: + kwargs['_mgmt_client'] = MockWorkspaceMgmtClient() + super().__init__(**kwargs) + def _create_client(self) -> ServicesClient: # type: ignore[override] # Pass through the Workspace's auth policy to the mock client auth_policy = self._connection_params.get_auth_policy() @@ -466,7 +515,9 @@ def seed_sessions(ws: WorkspaceMock) -> None: def create_default_workspace() -> WorkspaceMock: ws = WorkspaceMock( - subscription_id="sub", resource_group="rg", name="ws", location="westus" + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE ) seed_jobs(ws) seed_sessions(ws) diff --git a/azure-quantum/tests/unit/local/test_job_results.py b/azure-quantum/tests/unit/local/test_job_results.py index 7747a333..ccbde0e9 100644 --- a/azure-quantum/tests/unit/local/test_job_results.py +++ b/azure-quantum/tests/unit/local/test_job_results.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. ## +import pytest from unittest.mock import Mock - from azure.quantum import Job, JobDetails -def _mock_job(output_data_format: str, results_as_json_str: str) -> Job: +def _mock_job(output_data_format: str, results_as_json_str: str, status: str = "Succeeded") -> Job: job_details = JobDetails( id="", name="", @@ -18,7 +18,7 @@ def _mock_job(output_data_format: str, results_as_json_str: str) -> Job: input_data_format="", output_data_format=output_data_format, ) - job_details.status = "Succeeded" + job_details.status = status job = Job(workspace=None, job_details=job_details) job.has_completed = Mock(return_value=True) @@ -37,8 +37,8 @@ def decode(): return job -def _get_job_results(output_data_format: str, results_as_json_str: str): - job = _mock_job(output_data_format, results_as_json_str) +def _get_job_results(output_data_format: str, results_as_json_str: str, status: str = "Succeeded"): + job = _mock_job(output_data_format, results_as_json_str, status) return job.get_results() @@ -70,6 +70,35 @@ def test_job_for_microsoft_quantum_results_v1_success(): assert job_results["[1]"] == 0.50 +def test_job_get_results_with_completed_status(): + job_results = _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Completed", + ) + assert len(job_results.keys()) == 2 + assert job_results["[0]"] == 0.50 + assert job_results["[1]"] == 0.50 + + +def test_job_get_results_with_failed_status_raises_runtime_error(): + with pytest.raises(RuntimeError, match="Cannot retrieve results as job execution failed"): + _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Failed", + ) + + +def test_job_get_results_with_cancelled_status_raises_runtime_error(): + with pytest.raises(RuntimeError, match="Cannot retrieve results as job execution failed"): + _get_job_results( + "microsoft.quantum-results.v1", + '{"Histogram": ["[0]", 0.50, "[1]", 0.50]}', + "Cancelled", + ) + + def test_job_for_microsoft_quantum_results_v1_no_histogram_returns_raw_result(): job_result_raw = '{"NotHistogramProperty": ["[0]", 0.50, "[1]", 0.50]}' job_result = _get_job_results("microsoft.quantum-results.v1", job_result_raw) diff --git a/azure-quantum/tests/unit/local/test_mgmt_client.py b/azure-quantum/tests/unit/local/test_mgmt_client.py new file mode 100644 index 00000000..cad4e4c8 --- /dev/null +++ b/azure-quantum/tests/unit/local/test_mgmt_client.py @@ -0,0 +1,659 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## + +import pytest +from unittest.mock import MagicMock, patch +from http import HTTPStatus +from azure.core.exceptions import HttpResponseError +from azure.quantum._mgmt_client import WorkspaceMgmtClient +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from azure.quantum._constants import ConnectionConstants +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + WORKSPACE, + LOCATION, + ENDPOINT_URI, +) + + +def test_init_creates_client(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + assert len(client._policies) == 5 + + +def test_init_without_user_agent(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + + +def test_context_manager_enter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, '__enter__', return_value=mgmt_client._client): + result = mgmt_client.__enter__() + assert result == mgmt_client + + +def test_context_manager_exit(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, '__exit__') as mock_exit: + mgmt_client.__exit__(None, None, None) + mock_exit.assert_called_once() + + +def test_close(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + with patch.object(mgmt_client._client, 'close') as mock_close: + mgmt_client.close() + mock_close.assert_called_once() + + +def test_load_workspace_from_arg_success(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + connection_params.subscription_id = None + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arg(connection_params) + + assert connection_params.subscription_id == SUBSCRIPTION_ID + assert connection_params.resource_group == RESOURCE_GROUP + assert connection_params.workspace_name == WORKSPACE + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + +def test_load_workspace_from_arg_with_resource_group_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + resource_group=RESOURCE_GROUP + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert RESOURCE_GROUP in str(request.content) + + +def test_load_workspace_from_arg_with_location_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + location=LOCATION + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert LOCATION in str(request.content) + + +def test_load_workspace_from_arg_with_subscription_filter(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + request_body = request.content + assert 'subscriptions' in request_body + + +def test_load_workspace_from_arg_no_workspace_name(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams() + + with pytest.raises(ValueError, match="Workspace name must be specified"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_no_matching_workspace(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = {'data': []} + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="No matching workspace found"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_multiple_workspaces(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [ + { + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }, + { + 'name': WORKSPACE, + 'subscriptionId': 'another-sub-id', + 'resourceGroup': 'another-rg', + 'location': 'westus', + 'endpointUri': 'https://another.endpoint.com/' + } + ] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Multiple Azure Quantum workspaces found"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_incomplete_workspace_data(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve complete workspace details"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arg_request_exception(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from Azure Resource Graph"): + mgmt_client.load_workspace_from_arg(connection_params) + + +def test_load_workspace_from_arm_success(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arm(connection_params) + + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + +def test_load_workspace_from_arm_missing_required_params(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE + ) + + with pytest.raises(ValueError, match="Missing required connection parameters"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_workspace_not_found(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.NOT_FOUND + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(ValueError, match="not found in resource group"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_http_error(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.FORBIDDEN + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(HttpResponseError): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_missing_location(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve location"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_missing_endpoint(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': {} + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve endpoint uri"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_request_exception(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from ARM"): + mgmt_client.load_workspace_from_arm(connection_params) + + +def test_load_workspace_from_arm_uses_custom_api_version(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_version="2024-01-01" + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert "2024-01-01" in request.url + + +def test_load_workspace_from_arm_uses_default_api_version(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert ConnectionConstants.DEFAULT_WORKSPACE_API_VERSION in request.url + + +def test_load_workspace_from_arg_constructs_correct_url(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert "/providers/Microsoft.ResourceGraph/resources" in request.url + assert ConnectionConstants.DEFAULT_ARG_API_VERSION in request.url + + +def test_load_workspace_from_arm_constructs_correct_url(): + mock_credential = MagicMock() + base_url = ConnectionConstants.ARM_PRODUCTION_ENDPOINT + mgmt_client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + call_args = mock_send.call_args + request = call_args[0][0] + assert f"/subscriptions/{SUBSCRIPTION_ID}" in request.url + assert f"/resourceGroups/{RESOURCE_GROUP}" in request.url + assert f"/providers/Microsoft.Quantum/workspaces/{WORKSPACE}" in request.url diff --git a/azure-quantum/tests/unit/local/test_workspace.py b/azure-quantum/tests/unit/local/test_workspace.py index 574e86bf..26bf41f4 100644 --- a/azure-quantum/tests/unit/local/test_workspace.py +++ b/azure-quantum/tests/unit/local/test_workspace.py @@ -10,7 +10,7 @@ from azure.core.pipeline.policies import AzureKeyCredentialPolicy from azure.identity import EnvironmentCredential -from mock_client import WorkspaceMock +from mock_client import WorkspaceMock, MockWorkspaceMgmtClient from common import ( SUBSCRIPTION_ID, RESOURCE_GROUP, @@ -18,6 +18,7 @@ LOCATION, STORAGE, API_KEY, + ENDPOINT_URI, ) SIMPLE_RESOURCE_ID = ConnectionConstants.VALID_RESOURCE_ID( @@ -34,47 +35,110 @@ quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT(LOCATION), ) +SIMPLE_CONNECTION_STRING_V2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT_v2(LOCATION) +) + def test_create_workspace_instance_valid(): + def assert_all_required_params(ws: WorkspaceMock): + assert ws.subscription_id == SUBSCRIPTION_ID + assert ws.resource_group == RESOURCE_GROUP + assert ws.name == WORKSPACE + assert ws.location == LOCATION + assert ws._connection_params.quantum_endpoint == ENDPOINT_URI + ws = WorkspaceMock( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) - assert ws.subscription_id == SUBSCRIPTION_ID - assert ws.resource_group == RESOURCE_GROUP - assert ws.name == WORKSPACE - assert ws.location == LOCATION + assert_all_required_params(ws) ws = WorkspaceMock( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, storage=STORAGE, ) + assert_all_required_params(ws) assert ws.storage == STORAGE - ws = WorkspaceMock(resource_id=SIMPLE_RESOURCE_ID, location=LOCATION) - assert ws.subscription_id == SUBSCRIPTION_ID - assert ws.resource_group == RESOURCE_GROUP - assert ws.name == WORKSPACE - assert ws.location == LOCATION + ws = WorkspaceMock( + resource_id=SIMPLE_RESOURCE_ID, + ) + assert_all_required_params(ws) ws = WorkspaceMock( - resource_id=SIMPLE_RESOURCE_ID, storage=STORAGE, location=LOCATION + resource_id=SIMPLE_RESOURCE_ID, + storage=STORAGE, ) + assert_all_required_params(ws) assert ws.storage == STORAGE + ws = WorkspaceMock( + name=WORKSPACE, + ) + assert_all_required_params(ws) -def test_create_workspace_locations(): - # User-provided location name should be normalized ws = WorkspaceMock( + name=WORKSPACE, + storage=STORAGE, + ) + assert_all_required_params(ws) + assert ws.storage == STORAGE + + ws = WorkspaceMock( + name=WORKSPACE, + location=LOCATION, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, subscription_id=SUBSCRIPTION_ID, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + location=LOCATION, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( + name=WORKSPACE, resource_group=RESOURCE_GROUP, + ) + assert_all_required_params(ws) + + ws = WorkspaceMock( name=WORKSPACE, + resource_group=RESOURCE_GROUP, + location=LOCATION, + ) + assert_all_required_params(ws) + + +def test_create_workspace_locations(): + # Location name should be normalized + _mgmt_client = MockWorkspaceMgmtClient() + def mock_load_workspace_from_arm(connection_params): + connection_params.location = "East US" + connection_params.quantum_endpoint = ENDPOINT_URI + _mgmt_client.load_workspace_from_arm = mock_load_workspace_from_arm + + ws = WorkspaceMock( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, location="East US", + _mgmt_client=_mgmt_client, ) assert ws.location == "eastus" @@ -123,7 +187,7 @@ def test_workspace_from_connection_string(): wrong_subscription_id = "00000000-2BAD-2BAD-2BAD-000000000000" wrong_resource_group = "wrongrg" wrong_workspace = "wrong-workspace" - wrong_location = "wrong-location" + wrong_location = "westus" wrong_connection_string = ConnectionConstants.VALID_CONNECTION_STRING( subscription_id=wrong_subscription_id, @@ -188,6 +252,67 @@ def test_workspace_from_connection_string(): assert workspace.resource_group == RESOURCE_GROUP assert workspace.name == WORKSPACE +def test_workspace_from_connection_string_v2(): + """Test that v2 QuantumEndpoint format is correctly parsed.""" + with mock.patch.dict( + os.environ, + clear=True + ): + workspace = WorkspaceMock.from_connection_string(SIMPLE_CONNECTION_STRING_V2) + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) + +def test_workspace_from_connection_string_v2_dogfood(): + """Test v2 QuantumEndpoint with dogfood environment.""" + canary_location = "eastus2euap" + dogfood_connection_string_v2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_DOGFOOD_ENDPOINT_v2(canary_location) + ) + + with mock.patch.dict(os.environ, clear=True): + workspace = WorkspaceMock.from_connection_string(dogfood_connection_string_v2) + assert workspace.location == canary_location + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.resource_group == RESOURCE_GROUP + assert workspace.name == WORKSPACE + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + +def test_env_connection_string_v2(): + """Test v2 QuantumEndpoint from environment variable.""" + with mock.patch.dict(os.environ): + os.environ.clear() + os.environ[EnvironmentVariables.CONNECTION_STRING] = SIMPLE_CONNECTION_STRING_V2 + + workspace = WorkspaceMock() + assert workspace.location == LOCATION + assert workspace.subscription_id == SUBSCRIPTION_ID + assert workspace.name == WORKSPACE + assert workspace.resource_group == RESOURCE_GROUP + assert isinstance(workspace.credential, AzureKeyCredential) + assert workspace.credential.key == API_KEY + # pylint: disable=protected-access + assert isinstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + assert auth_policy._name == ConnectionConstants.QUANTUM_API_KEY_HEADER + assert id(auth_policy._credential) == id(workspace.credential) def test_create_workspace_instance_invalid(): def assert_value_error(exception: Exception): @@ -196,56 +321,22 @@ def assert_value_error(exception: Exception): with mock.patch.dict(os.environ): os.environ.clear() - # missing location + # missing workspace name try: WorkspaceMock( - location=None, # type: ignore[arg-type] subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, - name=WORKSPACE, + name=None ) assert False, "Expected ValueError" except ValueError as e: assert_value_error(e) - # missing location with resource id - try: - WorkspaceMock(resource_id=SIMPLE_RESOURCE_ID) - assert False, "Expected ValueError" - except ValueError as e: - assert_value_error(e) - - # missing subscription id + # provide only subscription id and resource group try: WorkspaceMock( - location=LOCATION, - subscription_id=None, # type: ignore[arg-type] - resource_group=RESOURCE_GROUP, - name=WORKSPACE, - ) - assert False, "Expected ValueError" - except ValueError as e: - assert_value_error(e) - - # missing resource group - try: - WorkspaceMock( - location=LOCATION, - subscription_id=SUBSCRIPTION_ID, - resource_group=None, # type: ignore[arg-type] - name=WORKSPACE, - ) - assert False, "Expected ValueError" - except ValueError as e: - assert_value_error(e) - - # missing workspace name - try: - WorkspaceMock( - location=LOCATION, subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, - name=None, # type: ignore[arg-type] ) assert False, "Expected ValueError" except ValueError as e: @@ -277,7 +368,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) assert ws.user_agent is None @@ -286,7 +376,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent, ) assert ws.user_agent == user_agent @@ -296,7 +385,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) ws.append_user_agent("featurex") assert ws.user_agent == "featurex" @@ -309,7 +397,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) assert ws.user_agent == app_id @@ -318,7 +405,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent, ) assert ws.user_agent == f"{app_id} {user_agent}" @@ -328,7 +414,6 @@ def test_workspace_user_agent_appid(): subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent, ) ws.append_user_agent("featurex") @@ -336,3 +421,47 @@ def test_workspace_user_agent_appid(): ws.append_user_agent(None) assert ws.user_agent == app_id + +def test_workspace_context_manager(): + """Test that Workspace can be used as a context manager""" + with WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) as ws: + # Verify workspace is properly initialized + assert ws.subscription_id == SUBSCRIPTION_ID + assert ws.resource_group == RESOURCE_GROUP + assert ws.name == WORKSPACE + assert ws.location == LOCATION + + # Verify internal clients are accessible + assert ws._client is not None + assert ws._mgmt_client is not None + +def test_workspace_context_manager_calls_enter_exit(): + """Test that __enter__ and __exit__ are called on internal clients""" + ws = WorkspaceMock( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + ) + + # Mock the internal clients' __enter__ and __exit__ methods + ws._client.__enter__ = mock.MagicMock(return_value=ws._client) + ws._client.__exit__ = mock.MagicMock(return_value=None) + ws._mgmt_client.__enter__ = mock.MagicMock(return_value=ws._mgmt_client) + ws._mgmt_client.__exit__ = mock.MagicMock(return_value=None) + + # Use workspace as context manager + with ws as context_ws: + # Verify __enter__ was called on both clients + ws._client.__enter__.assert_called_once() + ws._mgmt_client.__enter__.assert_called_once() + + # Verify context manager returns the workspace instance + assert context_ws is ws + + # Verify __exit__ was called on both clients after exiting context + ws._client.__exit__.assert_called_once() + ws._mgmt_client.__exit__.assert_called_once() diff --git a/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py b/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py new file mode 100644 index 00000000..fa411473 --- /dev/null +++ b/azure-quantum/tests/unit/local/test_workspace_connection_params_validation.py @@ -0,0 +1,329 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## + +import pytest +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams + + +def test_valid_subscription_ids(): + """Test that valid subscription_ids are accepted.""" + valid_ids = [ + "12345678-1234-1234-1234-123456789abc", + "ABCDEF01-2345-6789-ABCD-EF0123456789", + "abcdef01-2345-6789-abcd-ef0123456789", + ] + for subscription_id in valid_ids: + params = WorkspaceConnectionParams(subscription_id=subscription_id) + assert params.subscription_id == subscription_id + + +def test_invalid_subscription_ids(): + """Test that invalid subscription_ids raise ValueError.""" + invalid_ids = [ + ("not-a-guid", "Subscription ID must be a valid GUID."), + (12345, "Subscription ID must be a string."), + ] + for subscription_id, expected_message in invalid_ids: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(subscription_id=subscription_id) + assert expected_message in str(exc_info.value) + + +def test_valid_resource_groups(): + """Test that valid resource_groups are accepted.""" + valid_groups = [ + "my-resource-group", + "MyResourceGroup", + "resource_group_123", + "rg123", + "a" * 90, # Max length (90 chars) + "a", # Min length (1 char) + "Resource_Group-1", + "my.resource.group", # Periods allowed (except at end) + "group(test)", # Parentheses allowed + "group(test)name", + "(parentheses)", + "test-group_name", + "GROUP-123", + "123-group", + "Test.Group.Name", + "my-group.v2", + "rg_test(prod)-v1.2", + "café", # Unicode letters (Lo) + "日本語", # Unicode letters (Lo) + "Казан", # Unicode letters (Lu, Ll) + "αβγ", # Greek letters (Ll) + "test-café-123", # Mixed ASCII and Unicode + "group_名前", # Mixed ASCII and Unicode + "test.group(1)-name_v2", # Multiple special chars + ] + for resource_group in valid_groups: + params = WorkspaceConnectionParams(resource_group=resource_group) + assert params.resource_group == resource_group + + +def test_invalid_resource_groups(): + """Test that invalid resource_groups raise ValueError.""" + rg_invalid_chars_msg = "Resource group name can only include alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters that match the allowed characters." + invalid_groups = [ + ("my/resource/group", rg_invalid_chars_msg), + ("my\\resource\\group", rg_invalid_chars_msg), + ("my resource group", rg_invalid_chars_msg), + (12345, "Resource group name must be a string."), + ("group.", rg_invalid_chars_msg), # Period at end + ("my-group.", rg_invalid_chars_msg), # Period at end + ("test.group.", rg_invalid_chars_msg), # Period at end + ("a" * 91, "Resource group name must be between 1 and 90 characters long."), # Too long + ("group@test", rg_invalid_chars_msg), # @ symbol + ("group#test", rg_invalid_chars_msg), # # symbol + ("group$test", rg_invalid_chars_msg), # $ symbol + ("group%test", rg_invalid_chars_msg), # % symbol + ("group^test", rg_invalid_chars_msg), # ^ symbol + ("group&test", rg_invalid_chars_msg), # & symbol + ("group*test", rg_invalid_chars_msg), # * symbol + ("group+test", rg_invalid_chars_msg), # + symbol + ("group=test", rg_invalid_chars_msg), # = symbol + ("group[test]", rg_invalid_chars_msg), # Square brackets + ("group{test}", rg_invalid_chars_msg), # Curly brackets + ("group|test", rg_invalid_chars_msg), # Pipe + ("group:test", rg_invalid_chars_msg), # Colon + ("group;test", rg_invalid_chars_msg), # Semicolon + ("group\"test", rg_invalid_chars_msg), # Quote + ("group'test", rg_invalid_chars_msg), # Single quote + ("group", rg_invalid_chars_msg), # Angle brackets + ("group,test", rg_invalid_chars_msg), # Comma + ("group?test", rg_invalid_chars_msg), # Question mark + ("group!test", rg_invalid_chars_msg), # Exclamation mark + ("group`test", rg_invalid_chars_msg), # Backtick + ("group~test", rg_invalid_chars_msg), # Tilde + ("test\ngroup", rg_invalid_chars_msg), # Newline + ("test\tgroup", rg_invalid_chars_msg), # Tab + ] + for resource_group, expected_message in invalid_groups: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(resource_group=resource_group) + assert expected_message in str(exc_info.value) + + +def test_empty_resource_group(): + """Test that empty resource_group is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(resource_group="") + assert params.resource_group is None + + +def test_valid_workspace_names(): + """Test that valid workspace names are accepted.""" + valid_names = [ + "12", + "a1", + "1a", + "ab", + "myworkspace", + "WORKSPACE", + "MyWorkspace", + "myWorkSpace", + "myworkspacE", + "1234567890", + "123workspace", + "workspace123", + "w0rksp4c3", + "123abc456def", + "abc123", + # with hyphens + "my-workspace", + "my-work-space", + "workspace-with-a-long-name-that-is-still-valid", + "a-b-c-d-e", + "my-workspace-2", + "workspace-1-2-3", + "1-a", + "b-2", + "1-2", + "a-b", + "1-b-2", + "a-1-b", + "workspace" + "-" * 10 + "test", + "a" * 54, # Max length (54 chars) + "1" * 54, # Max length with numbers + ] + for workspace_name in valid_names: + params = WorkspaceConnectionParams(workspace_name=workspace_name) + assert params.workspace_name == workspace_name + + +def test_invalid_workspace_names(): + """Test that invalid workspace names raise ValueError.""" + not_valid_names = [ + ("a", "Workspace name must be between 2 and 54 characters long."), + ("1", "Workspace name must be between 2 and 54 characters long."), + ("a" * 55, "Workspace name must be between 2 and 54 characters long."), + ("1" * 55, "Workspace name must be between 2 and 54 characters long."), + ("my_workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my/workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("-myworkspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("myworkspace-", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + (12345, "Workspace name must be a string."), + ] + for workspace_name, expected_message in not_valid_names: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(workspace_name=workspace_name) + assert expected_message in str(exc_info.value) + + +def test_empty_workspace_name(): + """Test that empty workspace_name is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(workspace_name="") + assert params.workspace_name is None + + +def test_valid_locations(): + """Test that valid locations are accepted and normalized.""" + valid_locations = [ + ("East US", "eastus"), + ("West Europe", "westeurope"), + ("eastus", "eastus"), + ("westus2", "westus2"), + ("EASTUS", "eastus"), + ("WestUs2", "westus2"), + ("South Central US", "southcentralus"), + ("North Europe", "northeurope"), + ("Southeast Asia", "southeastasia"), + ("Japan East", "japaneast"), + ("UK South", "uksouth"), + ("Australia East", "australiaeast"), + ("Central India", "centralindia"), + ("France Central", "francecentral"), + ("Germany West Central", "germanywestcentral"), + ("Switzerland North", "switzerlandnorth"), + ("UAE North", "uaenorth"), + ("Brazil South", "brazilsouth"), + ("Korea Central", "koreacentral"), + ("South Africa North", "southafricanorth"), + ("Norway East", "norwayeast"), + ("Sweden Central", "swedencentral"), + ("Qatar Central", "qatarcentral"), + ("Poland Central", "polandcentral"), + ("Italy North", "italynorth"), + ("Israel Central", "israelcentral"), + ("Spain Central", "spaincentral"), + ("Austria East", "austriaeast"), + ("Belgium Central", "belgiumcentral"), + ("Chile Central", "chilecentral"), + ("Indonesia Central", "indonesiacentral"), + ("Malaysia West", "malaysiawest"), + ("Mexico Central", "mexicocentral"), + ("New Zealand North", "newzealandnorth"), + ("westus3", "westus3"), + ("canadacentral", "canadacentral"), + ("westcentralus", "westcentralus"), + ] + for location, expected in valid_locations: + params = WorkspaceConnectionParams(location=location) + assert params.location == expected + + +def test_invalid_locations(): + """Test that invalid locations raise ValueError.""" + location_invalid_region_msg = "Location must be one of the Azure regions listed in https://learn.microsoft.com/en-us/azure/reliability/regions-list." + invalid_locations = [ + (" ", location_invalid_region_msg), + ("invalid-region", location_invalid_region_msg), + ("us-east", location_invalid_region_msg), + ("east-us", location_invalid_region_msg), + ("westus4", location_invalid_region_msg), + ("southus", location_invalid_region_msg), + ("centraleurope", location_invalid_region_msg), + ("asiaeast", location_invalid_region_msg), + ("chinaeast", location_invalid_region_msg), + ("usgovtexas", location_invalid_region_msg), + ("East US 3", location_invalid_region_msg), + ("not a region", location_invalid_region_msg), + (12345, "Location must be a string."), + (3.14, "Location must be a string."), + (True, "Location must be a string."), + ] + for location, expected_message in invalid_locations: + with pytest.raises(ValueError) as exc_info: + WorkspaceConnectionParams(location=location) + assert expected_message in str(exc_info.value) + + +def test_empty_location(): + """Test that empty location is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(location="") + assert params.location is None + + # None is also allowed and treated as not set + params = WorkspaceConnectionParams(location=None) + assert params.location is None + + +def test_none_values_are_allowed(): + """Test that None values for optional fields are allowed.""" + # This should not raise any exceptions + params = WorkspaceConnectionParams( + subscription_id=None, + resource_group=None, + workspace_name=None, + location=None, + user_agent=None, + ) + assert params.subscription_id is None + assert params.resource_group is None + assert params.workspace_name is None + assert params.location is None + assert params.user_agent is None + + +def test_multiple_valid_parameters(): + """Test that multiple valid parameters work together.""" + params = WorkspaceConnectionParams( + subscription_id="12345678-1234-1234-1234-123456789abc", + resource_group="my-resource-group", + workspace_name="my-workspace", + location="East US", + user_agent="my-app/1.0", + ) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-resource-group" + assert params.workspace_name == "my-workspace" + assert params.location == "eastus" + assert params.user_agent == "my-app/1.0" + + +def test_validation_on_resource_id(): + """Test that validation works when using resource_id.""" + # Valid resource_id should work + resource_id = ( + "/subscriptions/12345678-1234-1234-1234-123456789abc" + "/resourceGroups/my-rg" + "/providers/Microsoft.Quantum" + "/Workspaces/my-ws" + ) + params = WorkspaceConnectionParams(resource_id=resource_id) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-rg" + assert params.workspace_name == "my-ws" + + +def test_validation_on_connection_string(): + """Test that validation works when using connection_string.""" + # Valid connection string should work + connection_string = ( + "SubscriptionId=12345678-1234-1234-1234-123456789abc;" + "ResourceGroupName=my-rg;" + "WorkspaceName=my-ws;" + "ApiKey=test-key;" + "QuantumEndpoint=https://eastus.quantum.azure.com/;" + ) + params = WorkspaceConnectionParams(connection_string=connection_string) + assert params.subscription_id == "12345678-1234-1234-1234-123456789abc" + assert params.resource_group == "my-rg" + assert params.workspace_name == "my-ws" + assert params.location == "eastus"