diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index c40a37b7f5..6c4d406a6f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ ] [project.optional-dependencies] -sglang = ["sglang-router==0.2.1"] +sglang = ["sglang-router==0.3.2"] [tool.setuptools.package-data] "dstack.gateway" = [ diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..5939751132 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,6 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +923,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,7 +1035,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1045,11 +1044,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router.type}] @ {wheel}" + return f"dstack-gateway[{router}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 51abddc70c..5223cdaa7c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,7 +66,6 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -863,9 +862,7 @@ def _wait_for_load_balancer_address( time.sleep(1) -def _get_gateway_commands( - authorized_keys: List[str], router: Optional[AnyRouterConfig] = None -) -> List[str]: +def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index db965a7697..f66053831a 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -28,6 +28,7 @@ parse_off_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point @@ -888,6 +889,14 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None + router: Annotated[ + Optional[AnyRouterConfig], + Field( + description=( + "Router configuration for the service. Requires a gateway with matching router enabled. " + ), + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index b342c0a73b..e9581b83f0 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import RouterType from dstack._internal.utils.tags import tags_validator @@ -63,8 +63,8 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[AnyRouterConfig], - Field(description="The router configuration"), + Optional[RouterType], + Field(description="The router type enabled on this gateway. E.g. 'sglang'."), ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") @@ -134,7 +134,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[AnyRouterConfig] = None + router: Optional[RouterType] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index e07631e12e..e42cd9976d 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -19,6 +19,10 @@ class SGLangRouterConfig(CoreModel): description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" ), ] = "cache_aware" + pd_disaggregation: Annotated[ + bool, + Field(description="Enable PD disaggregation mode for the SGLang router"), + ] = False AnyRouterConfig = SGLangRouterConfig diff --git a/src/dstack/_internal/proxy/gateway/repo/state_v1.py b/src/dstack/_internal/proxy/gateway/repo/state_v1.py index b49550ef5f..2f4ec79418 100644 --- a/src/dstack/_internal/proxy/gateway/repo/state_v1.py +++ b/src/dstack/_internal/proxy/gateway/repo/state_v1.py @@ -98,6 +98,7 @@ def parse_replica(replica: dict) -> Replica: ssh_destination=replica["ssh_host"], ssh_port=replica["ssh_port"], ssh_proxy=ssh_proxy, + internal_ip=replica.get("internal_ip"), ) diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index dd4f63f325..c5f4cf8a1a 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -80,6 +80,7 @@ async def register_replica( ssh_proxy=body.ssh_proxy, ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, + internal_ip=body.internal_ip, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 53a29f68ca..967e9d9960 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel): ssh_proxy: Optional[SSHConnectionParams] ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] + internal_ip: Optional[str] = None class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 867591ca13..a9b54347e4 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -89,3 +89,25 @@ def update_replicas(self, replica_urls: List[str]) -> None: Exception: If updating replicas fails. """ ... + + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: + """Add a worker to the router. + + Args: + url: Worker URL (e.g. http://10.0.5.134:8000). + worker_type: Type of worker ("regular", "prefill", or "decode"). + bootstrap_port: Bootstrap port for prefill workers (optional). + + Returns: + True if the worker was accepted, False otherwise. + """ + raise NotImplementedError + + def register_worker(self, url: str) -> bool: + """Register worker with one attempt (no polling). Returns True if ready and added.""" + raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index c3a0dfaae9..5214bb8e93 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -2,7 +2,6 @@ import subprocess import sys import time -import urllib.parse from typing import List, Optional import httpx @@ -68,6 +67,8 @@ def start(self) -> None: "--policy", self.config.policy, ] + if self.config.pd_disaggregation: + cmd.append("--pd-disaggregation") subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @@ -174,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None: # Add workers for worker_url in sorted(workers_to_add): - success = self._add_worker_to_router(worker_url) + success = self.register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -197,9 +198,16 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def _add_worker_to_router(self, worker_url: str) -> bool: + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: try: - payload = {"url": worker_url, "worker_type": "regular"} + payload: dict = {"url": url, "worker_type": worker_type} + if bootstrap_port is not None: + payload["bootstrap_port"] = bootstrap_port with httpx.Client(timeout=5.0) as client: response = client.post( f"http://{self.context.host}:{self.context.port}/workers", @@ -209,8 +217,9 @@ def _add_worker_to_router(self, worker_url: str) -> bool: response_data = response.json() if response_data.get("status") == "accepted": logger.info( - "Worker %s accepted by sglang router on port %s", - worker_url, + "Worker %s (type=%s) accepted by sglang router on port %s", + url, + worker_type, self.context.port, ) return True @@ -224,21 +233,65 @@ def _add_worker_to_router(self, worker_url: str) -> bool: else: logger.error( "Failed to add worker %s: status %d, %s", - worker_url, + url, response.status_code, response.text, ) return False except Exception: - logger.exception("Error adding worker %s", worker_url) + logger.exception("Error adding worker %s", url) + return False + + def register_worker(self, url: str) -> bool: + server_info_url = f"{url}/server_info" + try: + with httpx.Client(timeout=10) as client: + resp = client.get(server_info_url) + if resp.status_code != 200: + return False + data = resp.json() + if data.get("status") != "ready": + return False + disaggregation_mode = data.get("disaggregation_mode", "") + if disaggregation_mode == "prefill": + worker_type = "prefill" + bootstrap_port = data.get("disaggregation_bootstrap_port") + elif disaggregation_mode == "decode": + worker_type = "decode" + bootstrap_port = None + else: + worker_type = "regular" + bootstrap_port = None + logger.info( + "Registering worker %s (type=%s)", + url, + worker_type, + ) + return self.add_worker_to_router( + url, + worker_type, + bootstrap_port, + ) + except Exception: + logger.exception("Error registering worker %s", url) return False def _remove_worker_from_router(self, worker_url: str) -> bool: try: - encoded_url = urllib.parse.quote(worker_url, safe="") + current_workers = self._get_router_workers() + worker_id = None + for worker in current_workers: + url = worker.get("url") + if url and isinstance(url, str) and url == worker_url: + worker_id = worker.get("id") + if worker_id and isinstance(worker_id, str): + break + if not worker_id: + logger.exception("No worker id found for url %s", worker_url) + return False with httpx.Client(timeout=5.0) as client: response = client.delete( - f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" + f"http://{self.context.host}:{self.context.port}/workers/{worker_id}" ) if response.status_code == 202: response_data = response.json() diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index bbda92d91b..d79b0c7932 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -5,6 +5,7 @@ from asyncio import Lock from pathlib import Path from typing import Dict, Optional +from urllib.parse import urlparse import jinja2 from pydantic import BaseModel @@ -18,6 +19,7 @@ RouterContext, get_router, ) +from dstack._internal.proxy.lib import models from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -43,6 +45,8 @@ def render(self) -> str: class ReplicaConfig(BaseModel): id: str socket: Path + port: int + internal_ip: Optional[str] = None class LimitReqZoneConfig(BaseModel): @@ -95,7 +99,7 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - self._domain_to_worker_ports: Dict[str, list[int]] = {} + self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: @@ -144,33 +148,40 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: del self._domain_to_router[conf.domain] raise - allocated_ports = self._allocate_worker_ports(len(conf.replicas)) - replica_urls = [ - f"http://{router.context.host}:{port}" for port in allocated_ports - ] - - # Write router workers config - try: + if conf.router.pd_disaggregation: + # PD path: replica_urls from internal_ip (router talks directly to workers) + if any(not r.internal_ip for r in conf.replicas): + raise ProxyError( + "PD disaggregation requires internal IP for all replicas." + ) + replica_urls = [ + f"http://{replica.internal_ip}:{replica.port}" + for replica in conf.replicas + ] + self._domain_to_worker_urls[conf.domain] = replica_urls + else: + # Non-PD path: allocate gateway-local ports, nginx proxies to replica sockets + allocated_ports = self._allocate_worker_ports(len(conf.replicas)) + replica_urls = [ + f"http://{router.context.host}:{port}" for port in allocated_ports + ] if conf.replicas: - await run_async(self.write_router_workers_conf, conf, allocated_ports) - # Discard old worker ports if domain already has allocated ports (required for scaling case) - if conf.domain in self._domain_to_worker_ports: - old_worker_ports = self._domain_to_worker_ports[conf.domain] - for port in old_worker_ports: - self._allocated_worker_ports.discard(port) - self._domain_to_worker_ports[conf.domain] = allocated_ports - except Exception as e: - logger.exception( - "write_router_workers_conf failed for domain=%s: %s", conf.domain, e - ) - raise + await run_async( + self.write_router_workers_conf, + conf, + allocated_ports, + ) + if conf.domain in self._domain_to_worker_urls: + self._discard_ports(self._domain_to_worker_urls[conf.domain]) + self._domain_to_worker_urls[conf.domain] = replica_urls - # Update replicas to router (actual HTTP API calls to add workers) try: await run_async(router.update_replicas, replica_urls) except Exception as e: logger.exception( - "Failed to add replicas to router for domain=%s: %s", conf.domain, e + "Failed to add replicas to router for domain=%s: %s", + conf.domain, + e, ) raise @@ -178,7 +189,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.info("Registered %s domain %s", conf.type, conf.domain) - async def unregister(self, domain: str) -> None: + async def unregister(self, domain: str, service: models.Service) -> None: logger.debug("Unregistering domain %s", domain) conf_path = self._conf_dir / self.get_config_name(domain) if not conf_path.exists(): @@ -189,12 +200,16 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_router: router = self._domain_to_router[domain] # Remove all workers for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - replica_urls = [ - f"http://{router.context.host}:{port}" for port in worker_ports - ] - await run_async(router.remove_replicas, replica_urls) + if domain in self._domain_to_worker_urls: + worker_urls = self._domain_to_worker_urls[domain] + await run_async(router.remove_replicas, worker_urls) + pd_disaggregation = ( + service.router.pd_disaggregation if service.router else False + ) + if not pd_disaggregation: + self._discard_ports(worker_urls) + del self._domain_to_worker_urls[domain] + logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router await run_async(router.stop) # Remove from mappings @@ -203,14 +218,6 @@ async def unregister(self, domain: str) -> None: del self._router_port_to_domain[router_port] del self._domain_to_router[domain] - # Discard worker ports for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - for port in worker_ports: - self._allocated_worker_ports.discard(port) - del self._domain_to_worker_ports[domain] - logger.debug("Freed worker ports %s for domain %s", worker_ports, domain) - # Remove workers config file workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" if workers_conf_path.exists(): @@ -403,6 +410,12 @@ def _allocate_worker_ports(self, num_ports: int) -> list[int]: return allocated + def _discard_ports(self, urls: list[str]) -> None: + for u in urls: + parsed = urlparse(u) + if parsed.port is not None and parsed.port in self._allocated_worker_ports: + self._allocated_worker_ports.discard(parsed.port) + def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 636d8c38ec..fd523e8d12 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -116,7 +116,7 @@ async def unregister_service( ids=(r.id for r in service.replicas), service_conn_pool=service_conn_pool, ) - await nginx.unregister(service.domain_safe) + await nginx.unregister(service.domain_safe, service) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) @@ -136,6 +136,7 @@ async def register_replica( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + internal_ip: Optional[str] = None, ) -> None: replica = models.Replica( id=replica_id, @@ -145,6 +146,7 @@ async def register_replica( ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=internal_ip, ) async with lock: @@ -258,7 +260,12 @@ async def apply_service( service, repo, service_conn_pool ) replica_configs = [ - ReplicaConfig(id=replica.id, socket=conn.app_socket_path) + ReplicaConfig( + id=replica.id, + socket=conn.app_socket_path, + port=replica.app_port, + internal_ip=replica.internal_ip, + ) for replica, conn in replica_conns.items() ] service_config = await get_nginx_service_config(service, replica_configs) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index bf37e0b5aa..9ae6b4d2ad 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -27,6 +27,7 @@ class Replica(ImmutableModel): # Optional outer proxy, a head node/bastion ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None + internal_ip: Optional[str] = None class IPAddressPartitioningKey(ImmutableModel): diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index d4f1c831e8..e68b874728 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -99,6 +99,7 @@ async def register_replica( assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None + payload["internal_ip"] = jpd.internal_ip if not jpd.dockerized: payload.update( { diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 06aa5b0ef0..201351e50d 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -91,8 +92,17 @@ async def _register_service_in_gateway( raise ServerClientError("Gateway status is not running") gateway_configuration = get_gateway_configuration(gateway) + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + if gateway_configuration.router != RouterType.SGLANG: + raise ServerClientError( + f"Service requires a SGLang gateway but gateway '{gateway.name}' " + "does not have the SGLang router configured." + ) service_https = _get_service_https(run_spec, gateway_configuration) - router = gateway_configuration.router + router = run_spec.configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -152,6 +162,14 @@ async def _register_service_in_gateway( def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: assert run_spec.configuration.type == "service" + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + raise ServerClientError( + "Service with SGLang router configuration requires a gateway. " + "Please configure a gateway with the SGLang router enabled." + ) if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: # Note: if the user sets `https: `, it will be ignored silently # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted