-
Notifications
You must be signed in to change notification settings - Fork 213
Add pd disaggregated inference #3558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| parse_off_duration, | ||
| ) | ||
| from dstack._internal.core.models.resources import Range, ResourcesSpec | ||
| from dstack._internal.core.models.routers import AnyRouterConfig | ||
| from dstack._internal.core.models.services import AnyModel, OpenAIChatModel | ||
| from dstack._internal.core.models.unix import UnixUser | ||
| from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point | ||
|
|
@@ -885,6 +886,14 @@ class ServiceConfigurationParams(CoreModel): | |
| ) | ||
| ), | ||
| ] = None | ||
| router_config: Annotated[ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any new properties should be excluded from client requests for compatibility with older servers. See |
||
| Optional[AnyRouterConfig], | ||
| Field( | ||
| description=( | ||
| "Router configuration for the service. Currently supports routing policy and pd_disaggregation. " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) The supported properties (routing policy and `pd_disaggregation`) should already be visible in the |
||
| ), | ||
| ), | ||
| ] = None | ||
|
|
||
| @validator("port") | ||
| def convert_port(cls, v) -> PortMapping: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| from enum import Enum | ||
| from typing import Literal | ||
| from typing import Literal, Optional | ||
|
|
||
| from pydantic import Field | ||
| from typing_extensions import Annotated | ||
|
|
@@ -19,6 +19,30 @@ class SGLangRouterConfig(CoreModel): | |
| description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" | ||
| ), | ||
| ] = "cache_aware" | ||
| pd_disaggregation: Annotated[ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any new properties should be excluded from client requests for compatibility with older servers. See
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This adds `pd_disaggregation` to both service configurations and gateway configurations. I'd advocate for adding it to service configurations only.
|
||
| bool, | ||
| Field(description="Enable PD disaggregation mode for the SGLang router"), | ||
| ] = False | ||
|
|
||
|
|
||
| AnyRouterConfig = SGLangRouterConfig | ||
|
|
||
|
|
||
| def merge_router_config_for_service( | ||
| gateway_router: Optional[AnyRouterConfig], | ||
| service_router_config: Optional[AnyRouterConfig], | ||
| ) -> Optional[AnyRouterConfig]: | ||
| """Merge gateway and service router config. | ||
|
|
||
| Gateway router config supplies the router type; service router config supplies | ||
| policy and pd_disaggregation. The result combines both. | ||
| """ | ||
| if gateway_router is None: | ||
| return None | ||
| if service_router_config is None: | ||
| return gateway_router | ||
| return SGLangRouterConfig( | ||
| type=gateway_router.type, | ||
| policy=service_router_config.policy, | ||
| pd_disaggregation=service_router_config.pd_disaggregation, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -98,6 +98,7 @@ def parse_replica(replica: dict) -> Replica: | |||
| ssh_destination=replica["ssh_host"], | ||||
| ssh_port=replica["ssh_port"], | ||||
| ssh_proxy=ssh_proxy, | ||||
| internal_ip=replica.get("internal_ip"), | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit) Redundant, old
Suggested change
|
||||
| ) | ||||
|
|
||||
|
|
||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: | |
| ... | ||
|
|
||
| @abstractmethod | ||
| def update_replicas(self, replica_urls: List[str]) -> None: | ||
| async def update_replicas(self, replica_urls: List[str]) -> None: | ||
| """Update replicas for service, replacing the current set. | ||
|
|
||
| Args: | ||
|
|
@@ -89,3 +89,25 @@ def update_replicas(self, replica_urls: List[str]) -> None: | |
| Exception: If updating replicas fails. | ||
| """ | ||
| ... | ||
|
|
||
| async def add_worker_to_router( | ||
| self, | ||
| url: str, | ||
| worker_type: str = "regular", | ||
| bootstrap_port: Optional[int] = None, | ||
| ) -> bool: | ||
| """Add a worker to the router. | ||
|
|
||
| Args: | ||
| url: Worker URL (e.g. http://10.0.5.134:8000). | ||
| worker_type: Type of worker ("regular", "prefill", or "decode"). | ||
| bootstrap_port: Bootstrap port for prefill workers (optional). | ||
|
|
||
| Returns: | ||
| True if the worker was accepted, False otherwise. | ||
| """ | ||
| raise NotImplementedError | ||
|
|
||
| async def register_worker(self, url: str) -> bool: | ||
| """Register worker with one attempt (no polling). Returns True if ready and added.""" | ||
| raise NotImplementedError | ||
|
Comment on lines
+92
to
+113
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) These methods don't need to be in the |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |
| import subprocess | ||
| import sys | ||
| import time | ||
| import urllib.parse | ||
| from typing import List, Optional | ||
|
|
||
| import httpx | ||
|
|
@@ -68,6 +67,8 @@ def start(self) -> None: | |
| "--policy", | ||
| self.config.policy, | ||
| ] | ||
| if self.config.pd_disaggregation: | ||
| cmd.append("--pd-disaggregation") | ||
|
|
||
| subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | ||
|
|
||
|
|
@@ -140,7 +141,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: | |
| for replica_url in replica_urls: | ||
| self._remove_worker_from_router(replica_url) | ||
|
|
||
| def update_replicas(self, replica_urls: List[str]) -> None: | ||
| async def update_replicas(self, replica_urls: List[str]) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The method is now async, but it still calls I'd suggest to revert the method back to synchronous and call it outside of the event loop using Or make all |
||
| """Update replicas for service, replacing the current set.""" | ||
| # Query router to get current worker URLs | ||
| current_workers = self._get_router_workers() | ||
|
|
@@ -174,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None: | |
|
|
||
| # Add workers | ||
| for worker_url in sorted(workers_to_add): | ||
| success = self._add_worker_to_router(worker_url) | ||
| success = await self.register_worker(worker_url) | ||
| if not success: | ||
| logger.warning("Failed to add worker %s, continuing with others", worker_url) | ||
|
|
||
|
|
@@ -197,20 +198,28 @@ def _get_router_workers(self) -> List[dict]: | |
| logger.exception("Error getting sglang router workers") | ||
| return [] | ||
|
|
||
| def _add_worker_to_router(self, worker_url: str) -> bool: | ||
| async def add_worker_to_router( | ||
| self, | ||
| url: str, | ||
| worker_type: str = "regular", | ||
| bootstrap_port: Optional[int] = None, | ||
| ) -> bool: | ||
| try: | ||
| payload = {"url": worker_url, "worker_type": "regular"} | ||
| with httpx.Client(timeout=5.0) as client: | ||
| response = client.post( | ||
| payload: dict = {"url": url, "worker_type": worker_type} | ||
| if bootstrap_port is not None: | ||
| payload["bootstrap_port"] = bootstrap_port | ||
| async with httpx.AsyncClient(timeout=5.0) as client: | ||
| response = await client.post( | ||
| f"http://{self.context.host}:{self.context.port}/workers", | ||
| json=payload, | ||
| ) | ||
| if response.status_code == 202: | ||
| response_data = response.json() | ||
| if response_data.get("status") == "accepted": | ||
| logger.info( | ||
| "Worker %s accepted by sglang router on port %s", | ||
| worker_url, | ||
| "Worker %s (type=%s) accepted by sglang router on port %s", | ||
| url, | ||
| worker_type, | ||
| self.context.port, | ||
| ) | ||
| return True | ||
|
|
@@ -224,21 +233,65 @@ def _add_worker_to_router(self, worker_url: str) -> bool: | |
| else: | ||
| logger.error( | ||
| "Failed to add worker %s: status %d, %s", | ||
| worker_url, | ||
| url, | ||
| response.status_code, | ||
| response.text, | ||
| ) | ||
| return False | ||
| except Exception: | ||
| logger.exception("Error adding worker %s", worker_url) | ||
| logger.exception("Error adding worker %s", url) | ||
| return False | ||
|
|
||
| async def register_worker(self, url: str) -> bool: | ||
| server_info_url = f"{url}/server_info" | ||
| try: | ||
| async with httpx.AsyncClient(timeout=10) as client: | ||
| resp = await client.get(server_info_url) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) Avoid this request and assume |
||
| if resp.status_code != 200: | ||
| return False | ||
| data = resp.json() | ||
| if data.get("status") != "ready": | ||
| return False | ||
| disaggregation_mode = data.get("disaggregation_mode", "") | ||
| if disaggregation_mode == "prefill": | ||
| worker_type = "prefill" | ||
| bootstrap_port = data.get("disaggregation_bootstrap_port") | ||
| elif disaggregation_mode == "decode": | ||
| worker_type = "decode" | ||
| bootstrap_port = None | ||
| else: | ||
| worker_type = "regular" | ||
| bootstrap_port = None | ||
| logger.info( | ||
| "Registering worker %s (type=%s)", | ||
| url, | ||
| worker_type, | ||
| ) | ||
| return await self.add_worker_to_router( | ||
| url, | ||
| worker_type, | ||
| bootstrap_port, | ||
| ) | ||
| except Exception: | ||
| logger.exception("Error registering worker %s", url) | ||
| return False | ||
|
|
||
| def _remove_worker_from_router(self, worker_url: str) -> bool: | ||
| try: | ||
| encoded_url = urllib.parse.quote(worker_url, safe="") | ||
| current_workers = self._get_router_workers() | ||
| worker_id = None | ||
| for worker in current_workers: | ||
| url = worker.get("url") | ||
| if url and isinstance(url, str) and url == worker_url: | ||
| worker_id = worker.get("id") | ||
| if worker_id and isinstance(worker_id, str): | ||
| break | ||
| if not worker_id: | ||
| logger.exception("No worker id found for url %s", worker_url) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) |
||
| return False | ||
| with httpx.Client(timeout=5.0) as client: | ||
| response = client.delete( | ||
| f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" | ||
| f"http://{self.context.host}:{self.context.port}/workers/{worker_id}" | ||
|
Comment on lines
+281
to
+294
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is this change for?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Latest |
||
| ) | ||
| if response.status_code == 202: | ||
| response_data = response.json() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit)
`router_config` -> `router` for brevity and consistency with gateway configurations?