Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, List, Any, TYPE_CHECKING

from databricks.sql.common.http import HttpMethod
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
normalize_host_with_protocol(self._connection.session.host)
+ endpoint_suffix
)

# Use the provided HTTP client
Expand Down
31 changes: 31 additions & 0 deletions src/databricks/sql/common/url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
URL utility functions for the Databricks SQL connector.
"""


def normalize_host_with_protocol(host: str) -> str:
"""
Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes.

This is useful for handling cases where users may provide hostnames with or without protocols
(common with dbt-databricks users copying URLs from their browser).

Args:
host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : case insensitivity should be considered

and may or may not have a trailing slash

Returns:
Normalized hostname with protocol prefix and no trailing slash

Examples:
normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
"""
# Remove trailing slash
host = host.rstrip("/")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : check for NPE


# Add protocol if not present
if not host.startswith("https://") and not host.startswith("http://"):
host = f"https://{host}"

return host
3 changes: 2 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TelemetryPushClient,
CircuitBreakerTelemetryPushClient,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -278,7 +279,7 @@ def _send_telemetry(self, events):
if self._auth_provider
else self.TELEMETRY_UNAUTHENTICATED_PATH
)
url = f"https://{self._host_url}{path}"
url = normalize_host_with_protocol(self._host_url) + path

headers = {"Accept": "application/json", "Content-Type": "application/json"}

Expand Down
36 changes: 32 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
"access_token": "tok",
}

def _setup_mock_session_with_http_client(self, mock_session):
"""
Helper to configure a mock session with HTTP client mocks.
This prevents feature flag network requests during Connection initialization.
"""
mock_session.host = "foo"

# Mock HTTP client to prevent feature flag network requests
mock_http_client = Mock()
mock_session.http_client = mock_http_client

# Mock feature flag response to prevent blocking HTTP calls
mock_ff_response = Mock()
mock_ff_response.status = 200
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
mock_http_client.request.return_value = mock_ff_response

def _create_mock_connection(self, mock_session_class):
"""Helper to create a mocked connection for transaction tests."""
# Mock session
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"
mock_session.get_autocommit.return_value = True

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=False to test actual transaction functionality
Expand Down Expand Up @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
conn = self._create_mock_connection(mock_session_class)

mock_cursor = Mock()
original_error = DatabaseError(
"Original error", host_url="test-host"
)
original_error = DatabaseError("Original error", host_url="test-host")
mock_cursor.execute.side_effect = original_error

with patch.object(conn, "cursor", return_value=mock_cursor):
Expand Down Expand Up @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down Expand Up @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand All @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Tests for URL utility functions."""
import pytest
from databricks.sql.common.url_utils import normalize_host_with_protocol


class TestNormalizeHostWithProtocol:
"""Tests for normalize_host_with_protocol function."""

@pytest.mark.parametrize("input_host,expected_output", [
# Hostname without protocol - should add https://
("myserver.com", "https://myserver.com"),
("workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with https:// - should not duplicate
("https://myserver.com", "https://myserver.com"),
("https://workspace.databricks.com", "https://workspace.databricks.com"),

# Hostname with http:// - should preserve
("http://localhost", "http://localhost"),
("http://myserver.com:8080", "http://myserver.com:8080"),

# Hostname with port numbers
("myserver.com:443", "https://myserver.com:443"),
("https://myserver.com:443", "https://myserver.com:443"),
("http://localhost:8080", "http://localhost:8080"),

# Trailing slash - should be removed
("myserver.com/", "https://myserver.com"),
("https://myserver.com/", "https://myserver.com"),
("http://localhost/", "http://localhost"),

])
def test_normalize_host_with_protocol(self, input_host, expected_output):
"""Test host normalization with various input formats."""
assert normalize_host_with_protocol(input_host) == expected_output

Loading