diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..b7e1023 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,31 @@ +name: Run Tests and Lint + +on: + pull_request: + branches: [ main ] + +jobs: + build-and-test: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e .[dev] + + - name: Lint with ruff + run: | + ruff check . + + - name: Run tests + run: | + pip install pytest + pytest --maxfail=1 --disable-warnings -q diff --git a/lsproxy/auth.py b/lsproxy/auth.py index e868e51..3011f44 100644 --- a/lsproxy/auth.py +++ b/lsproxy/auth.py @@ -12,10 +12,20 @@ def base64url_encode(data): padding = b"=" encoded = base64.b64encode(data).replace(b"+", b"-").replace(b"/", b"_") - return encoded.rstrip(padding) + return encoded.rstrip(padding).decode("utf-8") def create_jwt(payload, secret): + # Validate inputs + if not isinstance(payload, dict): + raise TypeError("Payload must be a dictionary") + if not payload: + raise ValueError("Payload cannot be empty") + if not isinstance(secret, str): + raise TypeError("Secret must be a string") + if not secret: + raise ValueError("Secret cannot be empty") + # Create JWT header header = {"typ": "JWT", "alg": "HS256"} @@ -24,10 +34,10 @@ def create_jwt(payload, secret): encoded_payload = base64url_encode(payload) # Create signature - signing_input = encoded_header + b"." + encoded_payload - signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest() + signing_input = f"{encoded_header}.{encoded_payload}" + signature = hmac.new(secret.encode("utf-8"), signing_input.encode("utf-8"), hashlib.sha256).digest() encoded_signature = base64url_encode(signature) # Combine all parts - jwt = signing_input + b"." + encoded_signature - return jwt.decode("utf-8") + jwt = f"{signing_input}.{encoded_signature}" + return jwt diff --git a/lsproxy/client.py b/lsproxy/client.py index 832b022..b46b39a 100644 --- a/lsproxy/client.py +++ b/lsproxy/client.py @@ -5,7 +5,7 @@ # Only import type hints for Modal if type checking if TYPE_CHECKING: - import modal + import modal # noqa: F401 from .models import ( DefinitionResponse, @@ -37,16 +37,27 @@ def __init__( timeout: float = 10.0, auth_token: Optional[str] = None, ): + if auth_token == "": + raise ValueError("Token cannot be empty") + if auth_token is None: + raise ValueError("Token cannot be None") + self._client.base_url = base_url self._client.timeout = timeout headers = {"Content-Type": "application/json"} - if auth_token: - headers["Authorization"] = f"Bearer {auth_token}" + headers["Authorization"] = f"Bearer {auth_token}" self._client.headers = headers def _request(self, method: str, endpoint: str, **kwargs) -> httpx.Response: """Make HTTP request with retry logic and better error handling.""" try: + # Ensure headers from client are included in the request + if "headers" in kwargs: + headers = {**self._client.headers, **kwargs["headers"]} + else: + headers = self._client.headers + kwargs["headers"] = headers + response = self._client.request(method, endpoint, **kwargs) response.raise_for_status() return response @@ -104,7 +115,7 @@ def read_source_code(self, request: FileRange) -> ReadSourceCodeResponse: f"Expected FileRange, got {type(request).__name__}. Please use FileRange model to construct the request." ) response = self._request( - "POST", "/workspace/read-source-code", json=request.model_dump() + "POST", "/workspace/read-source-code", json={"range": request.model_dump()} ) return ReadSourceCodeResponse.model_validate_json(response.text) @@ -243,14 +254,17 @@ def initialize_with_modal( return client - def check_health(self) -> bool: - """Check if the server is healthy and ready.""" + def check_health(self) -> dict: + """Check if the server is healthy and ready. + + Returns: + dict: Health check response containing status and supported languages + """ try: response = self._request("GET", "/system/health") - health_data = response.json() - return health_data.get("status") == "ok" + return response.json() except Exception: - return False + return {"status": "error"} def close(self): """Close the HTTP client and cleanup Modal resources if present.""" diff --git a/pyproject.toml b/pyproject.toml index 8d83c30..f15d7fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ include = ["lsproxy*"] [project.optional-dependencies] dev = [ "ruff>=0.3.7", + "PyJWT>=2.8.0", ] modal = [ "modal>=0.56.4", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..1cc52cc --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the lsproxy package.""" diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..13ba280 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,72 @@ +"""Unit tests for authentication utilities.""" +import pytest +import jwt +from datetime import datetime, timedelta, timezone + +from lsproxy.auth import create_jwt, base64url_encode + + +@pytest.fixture +def sample_payload(): + """Create a sample JWT payload.""" + return { + "sub": "test-user", + "exp": int((datetime.now(timezone.utc) + timedelta(hours=1)).timestamp()) + } + + +@pytest.fixture +def sample_secret(): + """Create a sample secret key.""" + return "test-secret-key-1234" + + +def test_base64url_encode(): + """Test base64url encoding.""" + # Test basic encoding + assert base64url_encode(b"test") == "dGVzdA" + + # Test padding removal + assert base64url_encode(b"t") == "dA" + assert base64url_encode(b"te") == "dGU" + assert base64url_encode(b"tes") == "dGVz" + + # Test URL-safe characters + assert "+" not in base64url_encode(b"???") + assert "/" not in base64url_encode(b"???") + + +def test_create_jwt(sample_payload, sample_secret): + """Test JWT creation.""" + token = create_jwt(sample_payload, sample_secret) + + # Verify token structure + assert isinstance(token, str) + assert len(token.split(".")) == 3 + + # Verify token can be decoded + decoded = jwt.decode(token, sample_secret, algorithms=["HS256"]) + assert decoded["sub"] == sample_payload["sub"] + assert decoded["exp"] == sample_payload["exp"] + + + + +def test_create_jwt_invalid_payload(): + """Test JWT creation with invalid payload.""" + with pytest.raises(TypeError): + create_jwt("not a dict", "secret") + + with pytest.raises(ValueError): + create_jwt({}, "secret") # Empty payload + + +def test_create_jwt_invalid_secret(): + """Test JWT creation with invalid secret.""" + payload = {"sub": "test"} + + with pytest.raises(ValueError): + create_jwt(payload, "") # Empty secret + + with pytest.raises(TypeError): + create_jwt(payload, None) # None secret diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..4fd93b2 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,355 @@ +"""Unit tests for the lsproxy client.""" +import json +import pytest +from unittest.mock import patch + +from lsproxy.client import Lsproxy +from lsproxy.models import ( + Position, + FilePosition, + FileRange, + Symbol, + DefinitionResponse, + GetDefinitionRequest, + ReferencesResponse, + GetReferencesRequest, + ReadSourceCodeResponse, +) + + +@pytest.fixture +def mock_request(): + """Mock the httpx request.""" + with patch("httpx.Client.request") as mock: + mock.return_value = mock.Mock() + mock.return_value.status_code = 200 + def side_effect(method, endpoint, **kwargs): + # Ensure headers are in the expected format for tests + headers = kwargs.get("headers", {}) + if isinstance(headers, dict): + # Keep only the headers we care about for testing + filtered_headers = { + "Content-Type": headers.get("Content-Type", "application/json"), + "Authorization": "***" # Mask the actual token value + } + kwargs["headers"] = filtered_headers + mock.return_value.request_args = (method, endpoint) + mock.return_value.request_kwargs = kwargs + return mock.return_value + mock.side_effect = side_effect + yield mock + + +@pytest.fixture +def client(): + """Create a test client.""" + return Lsproxy(base_url="http://test.url", auth_token="test_token") + + +def test_definitions_in_file(client, mock_request): + """Test getting definitions in a file.""" + response_data = [ + { + "kind": "function", + "name": "test_func", + "identifier_position": { + "path": "test.py", + "position": {"line": 1, "character": 4} + }, + "range": { + "path": "test.py", + "start": {"line": 1, "character": 0}, + "end": {"line": 3, "character": 12} + } + } + ] + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + result = client.definitions_in_file("test.py") + assert len(result) == 1 + assert isinstance(result[0], Symbol) + assert result[0].kind == "function" + assert result[0].name == "test_func" + assert result[0].identifier_position.path == "test.py" + assert result[0].identifier_position.position.line == 1 + assert result[0].identifier_position.position.character == 4 + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args.args + kwargs = mock_request.call_args.kwargs + assert args[0] == "GET" + assert args[1] == "/symbol/definitions-in-file" + assert kwargs["params"] == {"file_path": "test.py"} + headers = kwargs["headers"] + assert "content-type" in headers + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_find_definition(client, mock_request): + """Test finding a definition.""" + response_data = { + "definitions": [ + { + "path": "test.py", + "position": {"line": 5, "character": 2} + } + ], + "source_code_context": [ + { + "range": { + "path": "test.py", + "start": {"line": 5, "character": 0}, + "end": {"line": 5, "character": 15} + }, + "source_code": "def test_func():" + } + ], + "raw_response": {"some": "raw_data"} + } + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + request = GetDefinitionRequest( + position=FilePosition(path="test.py", position=Position(line=10, character=8)), + include_raw_response=True, + include_source_code=True + ) + result = client.find_definition(request) + + assert isinstance(result, DefinitionResponse) + assert len(result.definitions) == 1 + assert result.definitions[0].path == "test.py" + assert result.definitions[0].position.line == 5 + assert result.raw_response == {"some": "raw_data"} + assert len(result.source_code_context) == 1 + assert result.source_code_context[0].source_code == "def test_func():" + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args[0] + kwargs = mock_request.call_args[1] + assert args[0] == "POST" + assert args[1] == "/symbol/find-definition" + assert kwargs["json"] == { + "position": { + "path": "test.py", + "position": {"line": 10, "character": 8} + }, + "include_raw_response": True, + "include_source_code": True + } + headers = kwargs["headers"] + assert "content-type" in headers + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_find_references(client, mock_request): + """Test finding references.""" + response_data = { + "references": [ + { + "path": "test.py", + "position": {"line": 15, "character": 4} + } + ], + "context": [ + { + "range": { + "path": "test.py", + "start": {"line": 15, "character": 0}, + "end": {"line": 15, "character": 20} + }, + "source_code": " result = test_func()" + } + ], + "raw_response": {"some": "raw_data"} + } + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + request = GetReferencesRequest( + identifier_position=FilePosition(path="test.py", position=Position(line=5, character=4)), + include_code_context_lines=2, + include_declaration=True, + include_raw_response=True + ) + result = client.find_references(request) + + assert isinstance(result, ReferencesResponse) + assert len(result.references) == 1 + assert result.references[0].path == "test.py" + assert result.references[0].position.line == 15 + assert result.raw_response == {"some": "raw_data"} + assert len(result.context) == 1 + assert result.context[0].source_code == " result = test_func()" + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args[0] + kwargs = mock_request.call_args[1] + assert args[0] == "POST" + assert args[1] == "/symbol/find-references" + assert kwargs["json"] == { + "identifier_position": { + "path": "test.py", + "position": {"line": 5, "character": 4} + }, + "include_code_context_lines": 2, + "include_declaration": True, + "include_raw_response": True + } + headers = kwargs["headers"] + assert "content-type" in headers + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_list_files(client, mock_request): + """Test listing files.""" + response_data = ["file1.py", "file2.py"] + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + result = client.list_files() + assert result == ["file1.py", "file2.py"] + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args[0] + kwargs = mock_request.call_args[1] + assert args[0] == "GET" + assert args[1] == "/workspace/list-files" + headers = kwargs["headers"] + assert "content-type" in headers + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_read_source_code(client, mock_request): + """Test reading source code.""" + response_data = { + "source_code": "def test_func():\n pass\n" + } + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + file_range = FileRange( + path="test.py", + start=Position(line=1, character=0), + end=Position(line=2, character=8) + ) + result = client.read_source_code(file_range) + + assert isinstance(result, ReadSourceCodeResponse) + assert result.source_code == "def test_func():\n pass\n" + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args[0] + kwargs = mock_request.call_args[1] + assert args[0] == "POST" + assert args[1] == "/workspace/read-source-code" + assert kwargs["json"] == { + "range": { + "path": "test.py", + "start": {"line": 1, "character": 0}, + "end": {"line": 2, "character": 8} + } + } + headers = kwargs["headers"] + assert "content-type" in headers + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_check_health(client, mock_request): + """Test health check.""" + response_data = { + "status": "ok", + "languages": ["python", "typescript_javascript", "rust", "cpp", "java", "golang", "php"] + } + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + result = client.check_health() + assert result["status"] == "ok" + assert set(result["languages"]) == { + "python", "typescript_javascript", "rust", "cpp", "java", "golang", "php" + } + + # Verify the request was made with correct method, endpoint, and parameters + mock_request.assert_called_once() + args = mock_request.call_args[0] + kwargs = mock_request.call_args[1] + assert args[0] == "GET" + assert args[1] == "/health" + headers = kwargs["headers"] + assert headers["content-type"].lower() == "application/json" + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_error_responses(client, mock_request): + """Test error responses.""" + # Test 400 Bad Request + mock_request.return_value.status_code = 400 + response_data = {"error": "Invalid request"} + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + with pytest.raises(ValueError) as exc_info: + client.definitions_in_file("test.py") + assert str(exc_info.value) == "Invalid request" + + # Test 500 Internal Server Error + mock_request.return_value.status_code = 500 + response_data = {"error": "Internal server error"} + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + with pytest.raises(RuntimeError) as exc_info: + client.definitions_in_file("test.py") + assert str(exc_info.value) == "Internal server error" + + +def test_authentication_headers(client, mock_request): + """Test that authentication headers are included in requests.""" + mock_request.return_value.json.return_value = [] + client.definitions_in_file("test.py") + + mock_request.assert_called_once() + headers = mock_request.call_args.kwargs["headers"] + assert "authorization" in headers + assert headers["authorization"].lower() == "bearer test_token" + + +def test_missing_token(): + """Test that missing auth token raises an error.""" + with pytest.raises(ValueError) as exc_info: + Lsproxy(base_url="http://test.url", auth_token="") + assert "token cannot be empty" in str(exc_info.value).lower() + + with pytest.raises(ValueError) as exc_info: + Lsproxy(base_url="http://test.url", auth_token=None) + assert "token cannot be none" in str(exc_info.value).lower() + + +def test_authentication_error(client, mock_request): + """Test authentication error response.""" + mock_request.return_value.status_code = 401 + response_data = {"error": "Invalid or expired token"} + mock_request.return_value.text = json.dumps(response_data) + mock_request.return_value.json.return_value = response_data + + with pytest.raises(ValueError) as exc_info: + client.definitions_in_file("test.py") + assert "invalid or expired token" in str(exc_info.value).lower()