Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 89 additions & 36 deletions src/uipath/_cli/cli_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@
"eval": eval,
}


class _ServerState:
"""Mutable server state, initialized lazily at server startup."""

def __init__(self) -> None:
self.lock: asyncio.Lock | None = None
self.baseline_env: dict[str, str] | None = None

def init(self) -> None:
"""Must be called inside a running event loop at server startup."""
if self.lock is not None:
return
self.lock = asyncio.Lock()
self.baseline_env = os.environ.copy()


_state = _ServerState()


DEFAULT_PRELOAD_MODULES = [
# Network/async - slowest to load
"pysignalr.client",
Expand Down Expand Up @@ -168,48 +187,78 @@ async def handle_start(request: web.Request) -> web.Response:
status=400,
)

# Save original state
original_cwd = os.getcwd()
original_env = os.environ.copy()

console.info(f"Original cwd: {original_cwd}")
console.info(f"Original cwd: {os.getcwd()}")
console.info(f"Requested working_dir: {working_dir}")

try:
if isinstance(env_vars, dict):
os.environ.update(env_vars)

if working_dir and isinstance(working_dir, str):
os.chdir(working_dir)

result = await asyncio.to_thread(cmd.main, args, standalone_mode=False)
if _state.lock is None or _state.baseline_env is None:
raise RuntimeError("Server state not initialized")

# Validate environmentVariables type early
if env_vars and not isinstance(env_vars, dict):
return web.json_response(
{
"success": True,
"job_key": job_key,
"result": result,
}
)
except SystemExit as e:
exit_code = e.code if isinstance(e.code, int) else 1
return web.json_response(
{
"success": exit_code == 0,
"job_key": job_key,
"error": None if exit_code == 0 else f"Exit code: {exit_code}",
}
)
except Exception as e:
return web.json_response(
{"success": False, "job_key": job_key, "error": str(e)},
status=500,
"success": False,
"error": "Invalid field: 'environmentVariables' must be a dict",
},
status=400,
)
finally:
# Restore original state
os.chdir(original_cwd)
os.environ.clear()
os.environ.update(original_env)

# Serialize command execution to prevent concurrent os.environ mutation
async with _state.lock:
original_cwd = os.getcwd()

try:
# Start from server baseline + request env vars only.
# This ensures no env vars from previous requests leak through.
os.environ.clear()
os.environ.update(_state.baseline_env)
if isinstance(env_vars, dict):
os.environ.update(env_vars)

if working_dir and isinstance(working_dir, str):
try:
os.chdir(working_dir)
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
return web.json_response(
{
"success": False,
"job_key": job_key,
"error": f"Cannot change to working directory: {e}",
},
status=400,
)

result = await asyncio.to_thread(cmd.main, args, standalone_mode=False)

return web.json_response(
{
"success": True,
"job_key": job_key,
"result": result,
}
)
except SystemExit as e:
exit_code = e.code if isinstance(e.code, int) else 1
return web.json_response(
{
"success": exit_code == 0,
"job_key": job_key,
"error": None if exit_code == 0 else f"Exit code: {exit_code}",
}
)
except Exception as e:
return web.json_response(
{"success": False, "job_key": job_key, "error": str(e)},
status=500,
)
finally:
# Restore to server baseline
try:
os.chdir(original_cwd)
except OSError:
pass
os.environ.clear()
os.environ.update(_state.baseline_env)


ALLOWED_HOSTS = {"127.0.0.1", "localhost", "[::1]"}
Expand Down Expand Up @@ -253,6 +302,8 @@ async def start_unix_server(
ack_socket_path: str, server_socket_path: str | None = None
) -> None:
"""Start Unix domain socket HTTP server."""
_state.init()

server_socket_path = server_socket_path or generate_socket_path()

if os.path.exists(server_socket_path):
Expand Down Expand Up @@ -280,6 +331,8 @@ async def start_unix_server(

async def start_tcp_server(host: str, port: int) -> None:
"""Start TCP HTTP server (Windows fallback)."""
_state.init()

app = create_app()
runner = web.AppRunner(app)
await runner.setup()
Expand Down
163 changes: 163 additions & 0 deletions tests/cli/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@ async def start_job(
return await response.json()


async def start_job_with_env(
port: int,
job_key: str,
command: str,
args: list[str],
env_vars: dict[str, str],
) -> dict[str, Any]:
"""Start a job on the server with environment variables."""
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://127.0.0.1:{port}/jobs/{job_key}/start",
json={
"command": command,
"args": args,
"environmentVariables": env_vars,
},
) as response:
return await response.json()


class TestServer:
@pytest.fixture
def server_port(self):
Expand Down Expand Up @@ -206,3 +226,146 @@ async def send_with_host(host: str):
status, body = asyncio.run(send_with_host(host))
assert status == 200, f"Host '{host}' should be allowed but got {status}"
assert body == "OK"


class TestServerEnvIsolation:
"""Test that environment variables are isolated between sequential server requests."""

@pytest.fixture
def server_port(self):
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]

@pytest.fixture
def env_snapshots(self):
return []

@pytest.fixture
def server_with_spy(self, server_port, env_snapshots):
"""Start server with a spy command that captures os.environ."""
import click

from uipath._cli import cli_server

@click.command()
def spy_cmd():
env_snapshots.append(dict(os.environ))

original_commands = cli_server.COMMANDS.copy()
cli_server.COMMANDS["spy"] = spy_cmd

def run_server():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(start_tcp_server("127.0.0.1", server_port))
except asyncio.CancelledError:
pass
finally:
loop.close()

thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(0.5)

yield server_port

cli_server.COMMANDS.clear()
cli_server.COMMANDS.update(original_commands)

def test_env_vars_do_not_leak_between_requests(
self, server_with_spy, env_snapshots
):
"""Env vars from request 1 must not be visible in request 2."""
port = server_with_spy

# Request 1: set TEST_VAR_A
asyncio.run(
start_job_with_env(
port,
"job-1",
"spy",
[],
{"TEST_VAR_A": "value_a"},
)
)

# Request 2: set TEST_VAR_B (but NOT TEST_VAR_A)
asyncio.run(
start_job_with_env(
port,
"job-2",
"spy",
[],
{"TEST_VAR_B": "value_b"},
)
)

assert len(env_snapshots) == 2

env_run1 = env_snapshots[0]
env_run2 = env_snapshots[1]

# Run 1 should have TEST_VAR_A
assert env_run1["TEST_VAR_A"] == "value_a"
assert "TEST_VAR_B" not in env_run1

# Run 2 should have TEST_VAR_B but NOT TEST_VAR_A
assert env_run2["TEST_VAR_B"] == "value_b"
assert "TEST_VAR_A" not in env_run2

def test_server_baseline_env_preserved(self, server_with_spy, env_snapshots):
"""Server baseline env vars (like PATH) should be available during command execution."""
from uipath._cli import cli_server

port = server_with_spy

asyncio.run(
start_job_with_env(
port,
"job-1",
"spy",
[],
{"CUSTOM_VAR": "custom_value"},
)
)

assert len(env_snapshots) == 1
env_run = env_snapshots[0]

# Baseline is captured at server start, not import time
baseline = cli_server._state.baseline_env
assert baseline is not None

# Baseline env vars should be present
assert env_run.get("PATH") == baseline.get("PATH")

# Request env var should override/add
assert env_run["CUSTOM_VAR"] == "custom_value"

def test_env_restored_after_request(self, server_with_spy):
"""os.environ should be restored to baseline after each request."""
from uipath._cli import cli_server

port = server_with_spy

asyncio.run(
start_job_with_env(
port,
"job-1",
"spy",
[],
{"SHOULD_NOT_PERSIST": "temporary"},
)
)

baseline = cli_server._state.baseline_env
assert baseline is not None

# After the request, os.environ should match baseline
assert "SHOULD_NOT_PERSIST" not in os.environ
for key in baseline:
assert os.environ.get(key) == baseline[key]