From 44ecd1b7a5b64916dc436f5640995dc363d91b2e Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 30 Jan 2026 14:22:54 +0545 Subject: [PATCH 1/6] Initial PD disaggregation implementation Test2 Internal IP Test Add worker with internal_ip Check status and register Add Status Ready Log Add Prefill-Decode Add PD to dstack Test register worker without poll Add router config in service config Update remove worker Clean Up router code Clean Up Further Cleanup --- gateway/pyproject.toml | 4 +- .../_internal/core/backends/base/compute.py | 3 +- .../core/backends/kubernetes/compute.py | 5 ++ .../_internal/core/models/configurations.py | 9 +++ src/dstack/_internal/core/models/routers.py | 21 ++++- .../_internal/proxy/gateway/repo/state_v1.py | 1 + .../proxy/gateway/routers/registry.py | 1 + .../proxy/gateway/schemas/registry.py | 1 + .../gateway/services/model_routers/base.py | 24 +++++- .../gateway/services/model_routers/sglang.py | 79 +++++++++++++++--- .../_internal/proxy/gateway/services/nginx.py | 80 ++++++++++--------- .../proxy/gateway/services/registry.py | 9 ++- src/dstack/_internal/proxy/lib/models.py | 1 + .../server/services/gateways/client.py | 1 + .../server/services/services/__init__.py | 6 +- 15 files changed, 189 insertions(+), 56 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index c40a37b7f5..243f37b939 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,11 +11,11 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", + "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/pd_design_test.tar.gz", ] [project.optional-dependencies] -sglang = ["sglang-router==0.2.1"] +sglang = ["sglang-router==0.3.2"] [tool.setuptools.package-data] "dstack.gateway" = [ diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..b07c510e9c 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,7 +1042,8 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 7f8ef9123f..f25b87d9cc 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -330,6 +330,11 @@ def update_provisioning_data( if not pod_ip: return provisioning_data.internal_ip = pod_ip + logger.debug( + "Replica pod %s internal_ip=%s (cluster_ip will be set from Service)", + provisioning_data.instance_id, + pod_ip, + ) service = self.api.read_namespaced_service( name=_get_pod_service_name(provisioning_data.instance_id), namespace=self.config.namespace, diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 9c8b40b6ec..7055b14073 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -28,6 +28,7 @@ parse_off_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point @@ -885,6 +886,14 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None + router_config: Annotated[ + Optional[AnyRouterConfig], + Field( + description=( + "Router configuration for the service (e.g. routing policy and pd_disaggregation). " + ), + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index e07631e12e..d80ad35873 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Literal, Optional from pydantic import Field from typing_extensions import Annotated @@ -19,6 +19,25 @@ class SGLangRouterConfig(CoreModel): description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" ), ] = "cache_aware" + pd_disaggregation: Annotated[ + bool, + Field(description="Enable PD disaggregation mode for the SGLang router"), + ] = False AnyRouterConfig = SGLangRouterConfig + + +def merge_router_config_for_service( + gateway_router: Optional[AnyRouterConfig], + service_router_config: Optional[AnyRouterConfig], +) -> Optional[AnyRouterConfig]: + if gateway_router is None: + return None + if service_router_config is None: + return gateway_router + return SGLangRouterConfig( + type=gateway_router.type, + policy=service_router_config.policy, + pd_disaggregation=service_router_config.pd_disaggregation, + ) diff --git a/src/dstack/_internal/proxy/gateway/repo/state_v1.py b/src/dstack/_internal/proxy/gateway/repo/state_v1.py index b49550ef5f..2f4ec79418 100644 --- a/src/dstack/_internal/proxy/gateway/repo/state_v1.py +++ b/src/dstack/_internal/proxy/gateway/repo/state_v1.py @@ -98,6 +98,7 @@ def parse_replica(replica: dict) -> Replica: ssh_destination=replica["ssh_host"], ssh_port=replica["ssh_port"], ssh_proxy=ssh_proxy, + internal_ip=replica.get("internal_ip"), ) diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index dd4f63f325..c5f4cf8a1a 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -80,6 +80,7 @@ async def register_replica( ssh_proxy=body.ssh_proxy, ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, + internal_ip=body.internal_ip, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 53a29f68ca..967e9d9960 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel): ssh_proxy: Optional[SSHConnectionParams] ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] + internal_ip: Optional[str] = None class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 867591ca13..129562e497 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: ... @abstractmethod - def update_replicas(self, replica_urls: List[str]) -> None: + async def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set. Args: @@ -89,3 +89,25 @@ def update_replicas(self, replica_urls: List[str]) -> None: Exception: If updating replicas fails. """ ... + + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: + """Add a worker to the router. + + Args: + url: Worker URL (e.g. http://10.0.5.134:8000). + worker_type: Type of worker ("regular", "prefill", or "decode"). + bootstrap_port: Bootstrap port for prefill workers (optional). + + Returns: + True if the worker was accepted, False otherwise. + """ + raise NotImplementedError + + async def register_worker(self, url: str) -> bool: + """Register worker with one attempt (no polling). Returns True if ready and added.""" + raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index c3a0dfaae9..7d976b0059 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -2,7 +2,6 @@ import subprocess import sys import time -import urllib.parse from typing import List, Optional import httpx @@ -10,6 +9,7 @@ from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from .base import Router, RouterContext @@ -68,6 +68,8 @@ def start(self) -> None: "--policy", self.config.policy, ] + if self.config.pd_disaggregation: + cmd.append("--pd-disaggregation") subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @@ -140,7 +142,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: for replica_url in replica_urls: self._remove_worker_from_router(replica_url) - def update_replicas(self, replica_urls: List[str]) -> None: + async def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set.""" # Query router to get current worker URLs current_workers = self._get_router_workers() @@ -172,9 +174,9 @@ def update_replicas(self, replica_urls: List[str]) -> None: self.context.port, ) - # Add workers + # Add workers: poll /server_info and register with discovered type for worker_url in sorted(workers_to_add): - success = self._add_worker_to_router(worker_url) + success = await self.register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -197,9 +199,16 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def _add_worker_to_router(self, worker_url: str) -> bool: + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: try: - payload = {"url": worker_url, "worker_type": "regular"} + payload: dict = {"url": url, "worker_type": worker_type} + if bootstrap_port is not None: + payload["bootstrap_port"] = bootstrap_port with httpx.Client(timeout=5.0) as client: response = client.post( f"http://{self.context.host}:{self.context.port}/workers", @@ -209,8 +218,9 @@ def _add_worker_to_router(self, worker_url: str) -> bool: response_data = response.json() if response_data.get("status") == "accepted": logger.info( - "Worker %s accepted by sglang router on port %s", - worker_url, + "Worker %s (type=%s) accepted by sglang router on port %s", + url, + worker_type, self.context.port, ) return True @@ -224,21 +234,66 @@ def _add_worker_to_router(self, worker_url: str) -> bool: else: logger.error( "Failed to add worker %s: status %d, %s", - worker_url, + url, response.status_code, response.text, ) return False except Exception: - logger.exception("Error adding worker %s", worker_url) + logger.exception("Error adding worker %s", url) + return False + + async def register_worker(self, url: str) -> bool: + server_info_url = f"{url}/server_info" + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(server_info_url) + if resp.status_code != 200: + return False + data = resp.json() + if data.get("status") != "ready": + return False + disaggregation_mode = data.get("disaggregation_mode", "") + if disaggregation_mode == "prefill": + worker_type = "prefill" + bootstrap_port = data.get("disaggregation_bootstrap_port") + elif disaggregation_mode == "decode": + worker_type = "decode" + bootstrap_port = None + else: + worker_type = "regular" + bootstrap_port = None + logger.info( + "Registering worker %s (type=%s)", + url, + worker_type, + ) + return await run_async( + self.add_worker_to_router, + url, + worker_type, + bootstrap_port, + ) + except Exception: + logger.exception("Error registering worker %s", url) return False def _remove_worker_from_router(self, worker_url: str) -> bool: try: - encoded_url = urllib.parse.quote(worker_url, safe="") + current_workers = self._get_router_workers() + worker_id = None + for worker in current_workers: + url = worker.get("url") + if url and isinstance(url, str) and url == worker_url: + worker_id = worker.get("id") + if worker_id and isinstance(worker_id, str): + break + if not worker_id: + logger.exception("No worker id found for url %s", worker_url) + return False with httpx.Client(timeout=5.0) as client: response = client.delete( - f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" + f"http://{self.context.host}:{self.context.port}/workers/{worker_id}" ) if response.status_code == 202: response_data = response.json() diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index bbda92d91b..85473923f7 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -5,6 +5,7 @@ from asyncio import Lock from pathlib import Path from typing import Dict, Optional +from urllib.parse import urlparse import jinja2 from pydantic import BaseModel @@ -43,6 +44,8 @@ def render(self) -> str: class ReplicaConfig(BaseModel): id: str socket: Path + port: int + internal_ip: Optional[str] = None class LimitReqZoneConfig(BaseModel): @@ -95,7 +98,8 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - self._domain_to_worker_ports: Dict[str, list[int]] = {} + # Domain -> list of worker URLs (used for remove_replicas; non-PD URLs are gateway-local) + self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: @@ -144,33 +148,37 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: del self._domain_to_router[conf.domain] raise - allocated_ports = self._allocate_worker_ports(len(conf.replicas)) - replica_urls = [ - f"http://{router.context.host}:{port}" for port in allocated_ports - ] - - # Write router workers config - try: + if conf.router.pd_disaggregation: + # PD path: replica_urls from internal_ip (router talks directly to workers) + replica_urls = [ + f"http://{replica.internal_ip}:{replica.port}" + for replica in conf.replicas + if replica.internal_ip + ] + self._domain_to_worker_urls[conf.domain] = replica_urls + else: + # Non-PD path: allocate gateway-local ports, nginx proxies to replica sockets + allocated_ports = self._allocate_worker_ports(len(conf.replicas)) + replica_urls = [ + f"http://{router.context.host}:{port}" for port in allocated_ports + ] if conf.replicas: - await run_async(self.write_router_workers_conf, conf, allocated_ports) - # Discard old worker ports if domain already has allocated ports (required for scaling case) - if conf.domain in self._domain_to_worker_ports: - old_worker_ports = self._domain_to_worker_ports[conf.domain] - for port in old_worker_ports: - self._allocated_worker_ports.discard(port) - self._domain_to_worker_ports[conf.domain] = allocated_ports - except Exception as e: - logger.exception( - "write_router_workers_conf failed for domain=%s: %s", conf.domain, e - ) - raise + await run_async( + self.write_router_workers_conf, + conf, + allocated_ports, + ) + if conf.domain in self._domain_to_worker_urls: + self._discard_ports(self._domain_to_worker_urls[conf.domain]) + self._domain_to_worker_urls[conf.domain] = replica_urls - # Update replicas to router (actual HTTP API calls to add workers) try: - await run_async(router.update_replicas, replica_urls) + await router.update_replicas(replica_urls) except Exception as e: logger.exception( - "Failed to add replicas to router for domain=%s: %s", conf.domain, e + "Failed to add replicas to router for domain=%s: %s", + conf.domain, + e, ) raise @@ -189,12 +197,12 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_router: router = self._domain_to_router[domain] # Remove all workers for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - replica_urls = [ - f"http://{router.context.host}:{port}" for port in worker_ports - ] - await run_async(router.remove_replicas, replica_urls) + if domain in self._domain_to_worker_urls: + worker_urls = self._domain_to_worker_urls[domain] + await run_async(router.remove_replicas, worker_urls) + self._discard_ports(worker_urls) + del self._domain_to_worker_urls[domain] + logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router await run_async(router.stop) # Remove from mappings @@ -203,14 +211,6 @@ async def unregister(self, domain: str) -> None: del self._router_port_to_domain[router_port] del self._domain_to_router[domain] - # Discard worker ports for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - for port in worker_ports: - self._allocated_worker_ports.discard(port) - del self._domain_to_worker_ports[domain] - logger.debug("Freed worker ports %s for domain %s", worker_ports, domain) - # Remove workers config file workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" if workers_conf_path.exists(): @@ -403,6 +403,12 @@ def _allocate_worker_ports(self, num_ports: int) -> list[int]: return allocated + def _discard_ports(self, urls: list[str]) -> None: + for u in urls: + parsed = urlparse(u) + if parsed.port is not None and parsed.port in self._allocated_worker_ports: + self._allocated_worker_ports.discard(parsed.port) + def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 636d8c38ec..061ee8bec8 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -136,6 +136,7 @@ async def register_replica( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + internal_ip: Optional[str] = None, ) -> None: replica = models.Replica( id=replica_id, @@ -145,6 +146,7 @@ async def register_replica( ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=internal_ip, ) async with lock: @@ -258,7 +260,12 @@ async def apply_service( service, repo, service_conn_pool ) replica_configs = [ - ReplicaConfig(id=replica.id, socket=conn.app_socket_path) + ReplicaConfig( + id=replica.id, + socket=conn.app_socket_path, + port=replica.app_port, + internal_ip=replica.internal_ip, + ) for replica, conn in replica_conns.items() ] service_config = await get_nginx_service_config(service, replica_configs) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index bf37e0b5aa..9ae6b4d2ad 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -27,6 +27,7 @@ class Replica(ImmutableModel): # Optional outer proxy, a head node/bastion ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None + internal_ip: Optional[str] = None class IPAddressPartitioningKey(ImmutableModel): diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index d4f1c831e8..e68b874728 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -99,6 +99,7 @@ async def register_replica( assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None + payload["internal_ip"] = jpd.internal_ip if not jpd.dockerized: payload.update( { diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 06aa5b0ef0..272f66d6b3 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import merge_router_config_for_service from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -92,7 +93,10 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) - router = gateway_configuration.router + router = merge_router_config_for_service( + gateway_configuration.router, + run_spec.configuration.router_config, + ) service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: From a90173b7612b5b695c9e380874c10cfe8e7d0f48 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 10 Feb 2026 13:14:36 +0545 Subject: [PATCH 2/6] Add pd disaggregation service --- gateway/pyproject.toml | 2 +- src/dstack/_internal/core/backends/base/compute.py | 3 +-- .../_internal/core/backends/kubernetes/compute.py | 5 ----- src/dstack/_internal/core/models/configurations.py | 2 +- src/dstack/_internal/core/models/routers.py | 5 +++++ .../proxy/gateway/services/model_routers/base.py | 2 +- .../proxy/gateway/services/model_routers/sglang.py | 12 +++++------- src/dstack/_internal/proxy/gateway/services/nginx.py | 1 - 8 files changed, 14 insertions(+), 18 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 243f37b939..6c4d406a6f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/pd_design_test.tar.gz", + "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", ] [project.optional-dependencies] diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index b07c510e9c..49513e3211 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,8 +1042,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index f25b87d9cc..7f8ef9123f 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -330,11 +330,6 @@ def update_provisioning_data( if not pod_ip: return provisioning_data.internal_ip = pod_ip - logger.debug( - "Replica pod %s internal_ip=%s (cluster_ip will be set from Service)", - provisioning_data.instance_id, - pod_ip, - ) service = self.api.read_namespaced_service( name=_get_pod_service_name(provisioning_data.instance_id), namespace=self.config.namespace, diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 7055b14073..955f90657c 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -890,7 +890,7 @@ class ServiceConfigurationParams(CoreModel): Optional[AnyRouterConfig], Field( description=( - "Router configuration for the service (e.g. routing policy and pd_disaggregation). " + "Router configuration for the service. Currently supports routing policy and pd_disaggregation. " ), ), ] = None diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index d80ad35873..aeabadc31a 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -32,6 +32,11 @@ def merge_router_config_for_service( gateway_router: Optional[AnyRouterConfig], service_router_config: Optional[AnyRouterConfig], ) -> Optional[AnyRouterConfig]: + """Merge gateway and service router config. + + Gateway router config supplies the router type; service router config supplies + policy and pd_disaggregation. The result combines both. + """ if gateway_router is None: return None if service_router_config is None: diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 129562e497..c8704e0cee 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -90,7 +90,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: """ ... - def add_worker_to_router( + async def add_worker_to_router( self, url: str, worker_type: str = "regular", diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index 7d976b0059..b187a0699f 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -9,7 +9,6 @@ from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError -from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from .base import Router, RouterContext @@ -174,7 +173,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: self.context.port, ) - # Add workers: poll /server_info and register with discovered type + # Add workers for worker_url in sorted(workers_to_add): success = await self.register_worker(worker_url) if not success: @@ -199,7 +198,7 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def add_worker_to_router( + async def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -209,8 +208,8 @@ def add_worker_to_router( payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port - with httpx.Client(timeout=5.0) as client: - response = client.post( + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( f"http://{self.context.host}:{self.context.port}/workers", json=payload, ) @@ -268,8 +267,7 @@ async def register_worker(self, url: str) -> bool: url, worker_type, ) - return await run_async( - self.add_worker_to_router, + return await self.add_worker_to_router( url, worker_type, bootstrap_port, diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 85473923f7..fdd6249adb 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -98,7 +98,6 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - # Domain -> list of worker URLs (used for remove_replicas; non-PD URLs are gateway-local) self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN From 5ec1f97a35ea7616c2fc1e9429f03616dd710514 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 08:12:45 +0545 Subject: [PATCH 3/6] Move router configuration to service --- .../_internal/core/backends/base/compute.py | 9 ++++---- .../core/backends/kubernetes/compute.py | 5 +---- .../_internal/core/models/configurations.py | 4 ++-- src/dstack/_internal/core/models/gateways.py | 8 +++---- src/dstack/_internal/core/models/routers.py | 22 +------------------ .../server/services/services/__init__.py | 6 +---- 6 files changed, 13 insertions(+), 41 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..5939751132 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,6 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +923,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,7 +1035,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1045,11 +1044,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router.type}] @ {wheel}" + return f"dstack-gateway[{router}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 7f8ef9123f..e9692344d9 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -50,7 +50,6 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Memory -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error, parse_memory @@ -1002,9 +1001,7 @@ def _add_authorized_key_to_jump_pod( ) -def _get_gateway_commands( - authorized_keys: List[str], router: Optional[AnyRouterConfig] = None -) -> List[str]: +def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 955f90657c..801f87daae 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -886,11 +886,11 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None - router_config: Annotated[ + router: Annotated[ Optional[AnyRouterConfig], Field( description=( - "Router configuration for the service. Currently supports routing policy and pd_disaggregation. " + "Router configuration for the service. Requires a gateway with matching router enabled. " ), ), ] = None diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index b342c0a73b..e9581b83f0 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import RouterType from dstack._internal.utils.tags import tags_validator @@ -63,8 +63,8 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[AnyRouterConfig], - Field(description="The router configuration"), + Optional[RouterType], + Field(description="The router type enabled on this gateway. E.g. 'sglang'."), ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") @@ -134,7 +134,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[AnyRouterConfig] = None + router: Optional[RouterType] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index aeabadc31a..e42cd9976d 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import Field from typing_extensions import Annotated @@ -26,23 +26,3 @@ class SGLangRouterConfig(CoreModel): AnyRouterConfig = SGLangRouterConfig - - -def merge_router_config_for_service( - gateway_router: Optional[AnyRouterConfig], - service_router_config: Optional[AnyRouterConfig], -) -> Optional[AnyRouterConfig]: - """Merge gateway and service router config. - - Gateway router config supplies the router type; service router config supplies - policy and pd_disaggregation. The result combines both. - """ - if gateway_router is None: - return None - if service_router_config is None: - return gateway_router - return SGLangRouterConfig( - type=gateway_router.type, - policy=service_router_config.policy, - pd_disaggregation=service_router_config.pd_disaggregation, - ) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 272f66d6b3..8f3d10c9dc 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,7 +26,6 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import merge_router_config_for_service from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -93,10 +92,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) - router = merge_router_config_for_service( - gateway_configuration.router, - run_spec.configuration.router_config, - ) + router = run_spec.configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: From 6e7dbe7abbe9f8cdc640d148885ccc3a4fdaf751 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 13:02:07 +0545 Subject: [PATCH 4/6] Resolve major comments --- .../gateway/services/model_routers/base.py | 6 +++--- .../gateway/services/model_routers/sglang.py | 18 +++++++++--------- .../_internal/proxy/gateway/services/nginx.py | 16 ++++++++++++---- .../proxy/gateway/services/registry.py | 2 +- .../server/services/services/__init__.py | 18 ++++++++++++++++++ 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index c8704e0cee..a9b54347e4 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: ... @abstractmethod - async def update_replicas(self, replica_urls: List[str]) -> None: + def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set. Args: @@ -90,7 +90,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: """ ... - async def add_worker_to_router( + def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -108,6 +108,6 @@ async def add_worker_to_router( """ raise NotImplementedError - async def register_worker(self, url: str) -> bool: + def register_worker(self, url: str) -> bool: """Register worker with one attempt (no polling). Returns True if ready and added.""" raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index b187a0699f..5214bb8e93 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -141,7 +141,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: for replica_url in replica_urls: self._remove_worker_from_router(replica_url) - async def update_replicas(self, replica_urls: List[str]) -> None: + def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set.""" # Query router to get current worker URLs current_workers = self._get_router_workers() @@ -175,7 +175,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: # Add workers for worker_url in sorted(workers_to_add): - success = await self.register_worker(worker_url) + success = self.register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -198,7 +198,7 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - async def add_worker_to_router( + def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -208,8 +208,8 @@ async def add_worker_to_router( payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( + with httpx.Client(timeout=5.0) as client: + response = client.post( f"http://{self.context.host}:{self.context.port}/workers", json=payload, ) @@ -242,11 +242,11 @@ async def add_worker_to_router( logger.exception("Error adding worker %s", url) return False - async def register_worker(self, url: str) -> bool: + def register_worker(self, url: str) -> bool: server_info_url = f"{url}/server_info" try: - async with httpx.AsyncClient(timeout=10) as client: - resp = await client.get(server_info_url) + with httpx.Client(timeout=10) as client: + resp = client.get(server_info_url) if resp.status_code != 200: return False data = resp.json() @@ -267,7 +267,7 @@ async def register_worker(self, url: str) -> bool: url, worker_type, ) - return await self.add_worker_to_router( + return self.add_worker_to_router( url, worker_type, bootstrap_port, diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index fdd6249adb..d79b0c7932 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -19,6 +19,7 @@ RouterContext, get_router, ) +from dstack._internal.proxy.lib import models from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -149,10 +150,13 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: if conf.router.pd_disaggregation: # PD path: replica_urls from internal_ip (router talks directly to workers) + if any(not r.internal_ip for r in conf.replicas): + raise ProxyError( + "PD disaggregation requires internal IP for all replicas." + ) replica_urls = [ f"http://{replica.internal_ip}:{replica.port}" for replica in conf.replicas - if replica.internal_ip ] self._domain_to_worker_urls[conf.domain] = replica_urls else: @@ -172,7 +176,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: self._domain_to_worker_urls[conf.domain] = replica_urls try: - await router.update_replicas(replica_urls) + await run_async(router.update_replicas, replica_urls) except Exception as e: logger.exception( "Failed to add replicas to router for domain=%s: %s", @@ -185,7 +189,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.info("Registered %s domain %s", conf.type, conf.domain) - async def unregister(self, domain: str) -> None: + async def unregister(self, domain: str, service: models.Service) -> None: logger.debug("Unregistering domain %s", domain) conf_path = self._conf_dir / self.get_config_name(domain) if not conf_path.exists(): @@ -199,7 +203,11 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_worker_urls: worker_urls = self._domain_to_worker_urls[domain] await run_async(router.remove_replicas, worker_urls) - self._discard_ports(worker_urls) + pd_disaggregation = ( + service.router.pd_disaggregation if service.router else False + ) + if not pd_disaggregation: + self._discard_ports(worker_urls) del self._domain_to_worker_urls[domain] logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 061ee8bec8..fd523e8d12 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -116,7 +116,7 @@ async def unregister_service( ids=(r.id for r in service.replicas), service_conn_pool=service_conn_pool, ) - await nginx.unregister(service.domain_safe) + await nginx.unregister(service.domain_safe, service) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 8f3d10c9dc..201351e50d 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -91,6 +92,15 @@ async def _register_service_in_gateway( raise ServerClientError("Gateway status is not running") gateway_configuration = get_gateway_configuration(gateway) + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + if gateway_configuration.router != RouterType.SGLANG: + raise ServerClientError( + f"Service requires a SGLang gateway but gateway '{gateway.name}' " + "does not have the SGLang router configured." + ) service_https = _get_service_https(run_spec, gateway_configuration) router = run_spec.configuration.router service_protocol = "https" if service_https else "http" @@ -152,6 +162,14 @@ async def _register_service_in_gateway( def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: assert run_spec.configuration.type == "service" + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + raise ServerClientError( + "Service with SGLang router configuration requires a gateway. " + "Please configure a gateway with the SGLang router enabled." + ) if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: # Note: if the user sets `https: `, it will be ignored silently # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted From 860ea230ff53cfd1b80925f353de2cfb15a0f84c Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 13:49:35 +0545 Subject: [PATCH 5/6] Resolve Lint Error --- src/dstack/_internal/core/backends/kubernetes/compute.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index e3a430b35c..5223cdaa7c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -862,9 +862,7 @@ def _wait_for_load_balancer_address( time.sleep(1) -def _get_gateway_commands( - authorized_keys: List[str], router: Optional[str] = None -) -> List[str]: +def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) From 38eee94052c9d44eb8d6a3ada91899761a46857d Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 20:35:02 +0545 Subject: [PATCH 6/6] Minor Update --- src/dstack/_internal/core/backends/base/compute.py | 9 +++++---- src/dstack/_internal/core/backends/kubernetes/compute.py | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 5939751132..ade1b3daeb 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,6 +39,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -923,7 +924,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[RouterType] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1035,7 +1036,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1044,11 +1045,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router}] @ {wheel}" + return f"dstack-gateway[{router.value}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[RouterType] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 5223cdaa7c..d98dfc94a8 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,6 +66,7 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -862,7 +863,9 @@ def _wait_for_load_balancer_address( time.sleep(1) -def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: +def _get_gateway_commands( + authorized_keys: List[str], router: Optional[RouterType] = None +) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands)