diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 032701f63..36e4b8a02 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -6,6 +6,7 @@ from typing import Dict, Optional, List, Any, TYPE_CHECKING from databricks.sql.common.http import HttpMethod +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -67,7 +68,8 @@ def __init__( endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__) self._feature_flag_endpoint = ( - f"https://{self._connection.session.host}{endpoint_suffix}" + normalize_host_with_protocol(self._connection.session.host) + + endpoint_suffix ) # Use the provided HTTP client diff --git a/src/databricks/sql/common/url_utils.py b/src/databricks/sql/common/url_utils.py new file mode 100644 index 000000000..0a6f89274 --- /dev/null +++ b/src/databricks/sql/common/url_utils.py @@ -0,0 +1,44 @@ +""" +URL utility functions for the Databricks SQL connector. +""" + + +def normalize_host_with_protocol(host: str) -> str: + """ + Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes. + + This is useful for handling cases where users may provide hostnames with or without protocols + (common with dbt-databricks users copying URLs from their browser). + + Args: + host: Connection hostname which may or may not include a protocol prefix (https:// or http://) + and may or may not have a trailing slash + + Returns: + Normalized hostname with protocol prefix and no trailing slash + + Examples: + normalize_host_with_protocol("myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("HTTPS://myserver.com") -> "https://myserver.com" + + Raises: + ValueError: If host is None or empty string + """ + # Handle None or empty host + if not host or not host.strip(): + raise ValueError("Host cannot be None or empty") + + # Remove trailing slash + host = host.rstrip("/") + + # Add protocol if not present (case-insensitive check) + host_lower = host.lower() + if not host_lower.startswith("https://") and not host_lower.startswith("http://"): + host = f"https://{host}" + elif host_lower.startswith("https://") or host_lower.startswith("http://"): + # Normalize protocol to lowercase + protocol_end = host.index("://") + 3 + host = host[:protocol_end].lower() + host[protocol_end:] + + return host diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 77d1a2f9c..bc1626a3d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -47,6 +47,7 @@ TelemetryPushClient, CircuitBreakerTelemetryPushClient, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -278,7 +279,7 @@ def _send_telemetry(self, events): if self._auth_provider else self.TELEMETRY_UNAUTHENTICATED_PATH ) - url = f"https://{self._host_url}{path}" + url = normalize_host_with_protocol(self._host_url) + path headers = {"Accept": "application/json", "Content-Type": "application/json"} diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py index 45c494d19..cd4d2b129 100644 --- a/tests/e2e/test_circuit_breaker.py +++ b/tests/e2e/test_circuit_breaker.py @@ -23,6 +23,46 @@ from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager +def wait_for_circuit_state(circuit_breaker, expected_state, timeout=5): + """ + Wait for circuit breaker to reach expected state with polling. + + Args: + circuit_breaker: The circuit breaker instance to monitor + expected_state: The expected state (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN) + timeout: Maximum time to wait in seconds + + Returns: + True if state reached, False if timeout + """ + start = time.time() + while time.time() - start < timeout: + if circuit_breaker.current_state == expected_state: + return True + time.sleep(0.1) # Poll every 100ms + return False + + +def wait_for_circuit_state_multiple(circuit_breaker, expected_states, timeout=5): + """ + Wait for circuit breaker to reach one of multiple expected states. + + Args: + circuit_breaker: The circuit breaker instance to monitor + expected_states: List of acceptable states + timeout: Maximum time to wait in seconds + + Returns: + True if any state reached, False if timeout + """ + start = time.time() + while time.time() - start < timeout: + if circuit_breaker.current_state in expected_states: + return True + time.sleep(0.1) + return False + + @pytest.fixture(autouse=True) def aggressive_circuit_breaker_config(): """ @@ -107,9 +147,13 @@ def mock_request(*args, **kwargs): time.sleep(0.5) if should_trigger: - # Circuit should be OPEN after 2 rate-limit failures + # Wait for circuit to open (async telemetry may take time) + assert wait_for_circuit_state(circuit_breaker, STATE_OPEN, timeout=5), \ + f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}" + + # Circuit should be OPEN after rate-limit failures assert circuit_breaker.current_state == STATE_OPEN - assert circuit_breaker.fail_counter == 2 + assert circuit_breaker.fail_counter >= 2 # At least 2 failures # Track requests before another query requests_before = request_count["count"] @@ -197,7 +241,9 @@ def mock_conditional_request(*args, **kwargs): cursor.fetchone() time.sleep(2) - assert circuit_breaker.current_state == STATE_OPEN + # Wait for circuit to open + assert wait_for_circuit_state(circuit_breaker, STATE_OPEN, timeout=5), \ + f"Circuit didn't open, state: {circuit_breaker.current_state}" # Wait for reset timeout (5 seconds in test) time.sleep(6) @@ -208,24 +254,20 @@ def mock_conditional_request(*args, **kwargs): # Execute query to trigger HALF_OPEN state cursor.execute("SELECT 3") cursor.fetchone() - time.sleep(1) - # Circuit should be recovering - assert circuit_breaker.current_state in [ - STATE_HALF_OPEN, - STATE_CLOSED, - ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + # Wait for circuit to start recovering + assert wait_for_circuit_state_multiple( + circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5 + ), f"Circuit didn't recover, state: {circuit_breaker.current_state}" # Execute more queries to fully recover cursor.execute("SELECT 4") cursor.fetchone() - time.sleep(1) - current_state = circuit_breaker.current_state - assert current_state in [ - STATE_CLOSED, - STATE_HALF_OPEN, - ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + # Wait for full recovery + assert wait_for_circuit_state_multiple( + circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5 + ), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}" if __name__ == "__main__": diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 8f8a97eae..5b6991931 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase): "access_token": "tok", } + def _setup_mock_session_with_http_client(self, mock_session): + """ + Helper to configure a mock session with HTTP client mocks. + This prevents feature flag network requests during Connection initialization. + """ + mock_session.host = "foo" + + # Mock HTTP client to prevent feature flag network requests + mock_http_client = Mock() + mock_session.http_client = mock_http_client + + # Mock feature flag response to prevent blocking HTTP calls + mock_ff_response = Mock() + mock_ff_response.status = 200 + mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}' + mock_http_client.request.return_value = mock_ff_response + def _create_mock_connection(self, mock_session_class): """Helper to create a mocked connection for transaction tests.""" - # Mock session mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" mock_session.get_autocommit.return_value = True + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=False to test actual transaction functionality @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): conn = self._create_mock_connection(mock_session_class) mock_cursor = Mock() - original_error = DatabaseError( - "Original error", host_url="test-host" - ) + original_error = DatabaseError("Original error", host_url="test-host") mock_cursor.execute.side_effect = original_error with patch.object(conn, "cursor", return_value=mock_cursor): @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) diff --git a/tests/unit/test_url_utils.py b/tests/unit/test_url_utils.py new file mode 100644 index 000000000..d45b68263 --- /dev/null +++ b/tests/unit/test_url_utils.py @@ -0,0 +1,70 @@ +"""Tests for URL utility functions.""" +import pytest +from databricks.sql.common.url_utils import normalize_host_with_protocol + + +class TestNormalizeHostWithProtocol: + """Tests for normalize_host_with_protocol function.""" + + @pytest.mark.parametrize("input_host,expected_output", [ + # Hostname without protocol - should add https:// + ("myserver.com", "https://myserver.com"), + ("workspace.databricks.com", "https://workspace.databricks.com"), + + # Hostname with https:// - should not duplicate + ("https://myserver.com", "https://myserver.com"), + ("https://workspace.databricks.com", "https://workspace.databricks.com"), + + # Hostname with http:// - should preserve + ("http://localhost", "http://localhost"), + ("http://myserver.com:8080", "http://myserver.com:8080"), + + # Hostname with port numbers + ("myserver.com:443", "https://myserver.com:443"), + ("https://myserver.com:443", "https://myserver.com:443"), + ("http://localhost:8080", "http://localhost:8080"), + + # Trailing slash - should be removed + ("myserver.com/", "https://myserver.com"), + ("https://myserver.com/", "https://myserver.com"), + ("http://localhost/", "http://localhost"), + + # Case-insensitive protocol handling - should normalize to lowercase + ("HTTPS://myserver.com", "https://myserver.com"), + ("HTTP://myserver.com", "http://myserver.com"), + ("HttPs://workspace.databricks.com", "https://workspace.databricks.com"), + ("HtTp://localhost:8080", "http://localhost:8080"), + ("HTTPS://MYSERVER.COM", "https://MYSERVER.COM"), # Only protocol lowercased + + # Case-insensitive with trailing slashes + ("HTTPS://myserver.com/", "https://myserver.com"), + ("HTTP://localhost:8080/", "http://localhost:8080"), + ("HttPs://workspace.databricks.com//", "https://workspace.databricks.com"), + + # Mixed case protocols with ports + ("HTTPS://myserver.com:443", "https://myserver.com:443"), + ("HtTp://myserver.com:8080", "http://myserver.com:8080"), + + # Case preservation - only protocol lowercased, hostname case preserved + ("HTTPS://MyServer.DataBricks.COM", "https://MyServer.DataBricks.COM"), + ("HttPs://CamelCase.Server.com", "https://CamelCase.Server.com"), + ("HTTP://UPPERCASE.COM:8080", "http://UPPERCASE.COM:8080"), + ]) + def test_normalize_host_with_protocol(self, input_host, expected_output): + """Test host normalization with various input formats.""" + result = normalize_host_with_protocol(input_host) + assert result == expected_output + + # Additional assertion: verify protocol is always lowercase + assert result.startswith("https://") or result.startswith("http://") + + @pytest.mark.parametrize("invalid_host", [ + None, + "", + " ", # Whitespace only + ]) + def test_normalize_host_with_protocol_raises_on_invalid_input(self, invalid_host): + """Test that function raises ValueError for None or empty host.""" + with pytest.raises(ValueError, match="Host cannot be None or empty"): + normalize_host_with_protocol(invalid_host) +