Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.1"]
sglang = ["sglang-router==0.3.2"]

[tool.setuptools.package-data]
"dstack.gateway" = [
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
parse_off_duration,
)
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
Expand Down Expand Up @@ -885,6 +886,14 @@ class ServiceConfigurationParams(CoreModel):
)
),
] = None
router_config: Annotated[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) router_config -> router for brevity and consistency with gateway configurations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any new properties should be excluded from client requests for compatibility with older servers.

See get_run_spec_excludes.

Optional[AnyRouterConfig],
Field(
description=(
"Router configuration for the service. Currently supports routing policy and pd_disaggregation. "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) The supported properties (routing policy and pd_disaggregation) should already be visible in the AnyRouterConfig reference. Duplicating them here may lead to inconsistencies when properties are added or removed in the future.

),
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
Expand Down
26 changes: 25 additions & 1 deletion src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Literal
from typing import Literal, Optional

from pydantic import Field
from typing_extensions import Annotated
Expand All @@ -19,6 +19,30 @@ class SGLangRouterConfig(CoreModel):
description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
),
] = "cache_aware"
pd_disaggregation: Annotated[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any new properties should be excluded from client requests for compatibility with older servers.

See get_run_spec_excludes and _get_gateway_configuration_excludes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds pd_disaggregation to both service configurations and gateway configurations. I'd advocate for adding it to service configurations only.

  • Whether or not the service is configured to use PD disaggregation is clearly a service property, because it depends on the replica groups configuration. I don't think many users would want to configure that property at the gateway level, making assumptions about what services are going to run on that gateway in the future.

  • Having two places for the same property complicates the interface — you'd have to explain in the docs how these places are related to each other, when and how one setting overrides the other, etc.

  • Having the property at the gateway level can potentially complicate further development — that way, you can only tell whether a service is using PD disaggregation if the service is already registered, and you need to fetch the GatewayModel object from the database to do so. For example, this would complicate adding the default probe for services with PD disaggregation.

bool,
Field(description="Enable PD disaggregation mode for the SGLang router"),
] = False


AnyRouterConfig = SGLangRouterConfig


def merge_router_config_for_service(
    gateway_router: Optional[AnyRouterConfig],
    service_router_config: Optional[AnyRouterConfig],
) -> Optional[AnyRouterConfig]:
    """Merge gateway and service router config.

    The gateway router config supplies the router type; the service router
    config supplies ``policy`` and ``pd_disaggregation``.

    Args:
        gateway_router: Router config set on the gateway, or None.
        service_router_config: Router config set on the service, or None.

    Returns:
        None if ``gateway_router`` is None (a service-level config is ignored
        in that case). ``gateway_router`` itself, unchanged, if
        ``service_router_config`` is None. Otherwise a new
        ``SGLangRouterConfig`` with ``type`` taken from the gateway config and
        ``policy`` / ``pd_disaggregation`` taken from the service config.
    """
    # Without a gateway-level router there is no router type to merge into.
    if gateway_router is None:
        return None
    # No service-level overrides: the gateway config applies as-is.
    if service_router_config is None:
        return gateway_router
    return SGLangRouterConfig(
        type=gateway_router.type,
        policy=service_router_config.policy,
        pd_disaggregation=service_router_config.pd_disaggregation,
    )
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/repo/state_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def parse_replica(replica: dict) -> Replica:
ssh_destination=replica["ssh_host"],
ssh_port=replica["ssh_port"],
ssh_proxy=ssh_proxy,
internal_ip=replica.get("internal_ip"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Redundant: old state.json files never had the internal_ip property.

Suggested change
internal_ip=replica.get("internal_ip"),

)


Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def register_replica(
ssh_proxy=body.ssh_proxy,
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_proxy: Optional[SSHConnectionParams]
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None


class RegisterEntrypointRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None:
...

@abstractmethod
def update_replicas(self, replica_urls: List[str]) -> None:
async def update_replicas(self, replica_urls: List[str]) -> None:
"""Update replicas for service, replacing the current set.

Args:
Expand All @@ -89,3 +89,25 @@ def update_replicas(self, replica_urls: List[str]) -> None:
Exception: If updating replicas fails.
"""
...

async def add_worker_to_router(
    self,
    url: str,
    worker_type: str = "regular",
    bootstrap_port: Optional[int] = None,
) -> bool:
    """Add a worker to the router.

    Args:
        url: Worker URL (e.g. http://10.0.5.134:8000).
        worker_type: Type of worker ("regular", "prefill", or "decode").
        bootstrap_port: Bootstrap port for prefill workers (optional).

    Returns:
        True if the worker was accepted, False otherwise.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

async def register_worker(self, url: str) -> bool:
    """Register worker with one attempt (no polling).

    Args:
        url: Worker URL.

    Returns:
        True if the worker is ready and was added.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
Comment on lines +92 to +113
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) These methods don't need to be in the Router base class: they are only ever called by SglangRouter.update_replicas and never by external callers. Consider removing them from Router and making them private in SglangRouter (prefixed with _).

79 changes: 66 additions & 13 deletions src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import subprocess
import sys
import time
import urllib.parse
from typing import List, Optional

import httpx
Expand Down Expand Up @@ -68,6 +67,8 @@ def start(self) -> None:
"--policy",
self.config.policy,
]
if self.config.pd_disaggregation:
cmd.append("--pd-disaggregation")

subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

Expand Down Expand Up @@ -140,7 +141,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None:
for replica_url in replica_urls:
self._remove_worker_from_router(replica_url)

def update_replicas(self, replica_urls: List[str]) -> None:
async def update_replicas(self, replica_urls: List[str]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method is now async, but it still calls _get_router_workers and _remove_worker_from_router, which are synchronous and perform blocking I/O using httpx.Client. This can block the gateway's event loop for a long time, during which the gateway will be inoperable.

I'd suggest reverting the method to synchronous and calling it outside of the event loop using run_async, as was done before.

Alternatively, make all SglangRouter methods async and ensure there are no blocking calls (no httpx.Client, subprocess.Popen, etc.). But that would require too many unrelated changes, so I'd suggest sticking to the synchronous interface in this PR.

"""Update replicas for service, replacing the current set."""
# Query router to get current worker URLs
current_workers = self._get_router_workers()
Expand Down Expand Up @@ -174,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None:

# Add workers
for worker_url in sorted(workers_to_add):
success = self._add_worker_to_router(worker_url)
success = await self.register_worker(worker_url)
if not success:
logger.warning("Failed to add worker %s, continuing with others", worker_url)

Expand All @@ -197,20 +198,28 @@ def _get_router_workers(self) -> List[dict]:
logger.exception("Error getting sglang router workers")
return []

def _add_worker_to_router(self, worker_url: str) -> bool:
async def add_worker_to_router(
self,
url: str,
worker_type: str = "regular",
bootstrap_port: Optional[int] = None,
) -> bool:
try:
payload = {"url": worker_url, "worker_type": "regular"}
with httpx.Client(timeout=5.0) as client:
response = client.post(
payload: dict = {"url": url, "worker_type": worker_type}
if bootstrap_port is not None:
payload["bootstrap_port"] = bootstrap_port
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"http://{self.context.host}:{self.context.port}/workers",
json=payload,
)
if response.status_code == 202:
response_data = response.json()
if response_data.get("status") == "accepted":
logger.info(
"Worker %s accepted by sglang router on port %s",
worker_url,
"Worker %s (type=%s) accepted by sglang router on port %s",
url,
worker_type,
self.context.port,
)
return True
Expand All @@ -224,21 +233,65 @@ def _add_worker_to_router(self, worker_url: str) -> bool:
else:
logger.error(
"Failed to add worker %s: status %d, %s",
worker_url,
url,
response.status_code,
response.text,
)
return False
except Exception:
logger.exception("Error adding worker %s", worker_url)
logger.exception("Error adding worker %s", url)
return False

async def register_worker(self, url: str) -> bool:
server_info_url = f"{url}/server_info"
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get(server_info_url)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Avoid this request and assume worker_type = "regular" if pd_disaggregation is disabled?

if resp.status_code != 200:
return False
data = resp.json()
if data.get("status") != "ready":
return False
disaggregation_mode = data.get("disaggregation_mode", "")
if disaggregation_mode == "prefill":
worker_type = "prefill"
bootstrap_port = data.get("disaggregation_bootstrap_port")
elif disaggregation_mode == "decode":
worker_type = "decode"
bootstrap_port = None
else:
worker_type = "regular"
bootstrap_port = None
logger.info(
"Registering worker %s (type=%s)",
url,
worker_type,
)
return await self.add_worker_to_router(
url,
worker_type,
bootstrap_port,
)
except Exception:
logger.exception("Error registering worker %s", url)
return False

def _remove_worker_from_router(self, worker_url: str) -> bool:
try:
encoded_url = urllib.parse.quote(worker_url, safe="")
current_workers = self._get_router_workers()
worker_id = None
for worker in current_workers:
url = worker.get("url")
if url and isinstance(url, str) and url == worker_url:
worker_id = worker.get("id")
if worker_id and isinstance(worker_id, str):
break
if not worker_id:
logger.exception("No worker id found for url %s", worker_url)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) logger.exception should only be called from an exception handler. You can use logger.error

return False
with httpx.Client(timeout=5.0) as client:
response = client.delete(
f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}"
f"http://{self.context.host}:{self.context.port}/workers/{worker_id}"
Comment on lines +281 to +294
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this change for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest sglang_router (0.3.2) does not support the /{worker_url} endpoint for deletion.

)
if response.status_code == 202:
response_data = response.json()
Expand Down
Loading