2 changes: 1 addition & 1 deletion skills/dstack/SKILL.md
@@ -222,7 +222,7 @@ resources:
- Without gateway: `<dstack server URL>/proxy/services/f/<run name>/`
- With gateway: `https://<run name>.<gateway domain>/`
- Authentication: Unless `auth` is `false`, include `Authorization: Bearer <DSTACK_TOKEN>` on all service requests.
- OpenAI-compatible models: Use `service.url` from `dstack run get <run name> --json` and append `/v1` as the base URL; do **not** use deprecated `service.model.base_url` for requests.
- Model endpoint: If `model` is set, `service.model.base_url` from `dstack run get <run name> --json` provides the model endpoint. For OpenAI-compatible models (the default, unless format is set otherwise), this will be `service.url` + `/v1`.
- Example (with gateway):
```bash
curl -sS -X POST "https://<run name>.<gateway domain>/v1/chat/completions" \
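
# (Sketch, not part of the PR: resolving the model endpoint as described above.
#  Assumes jq is available, the run serves an OpenAI-format model, and the JSON field
#  paths match the names used above; the run name is a placeholder.)
BASE_URL=$(dstack run get my-service --json | jq -r '.service.model.base_url')   # equals service.url + /v1
curl -sS "$BASE_URL/models" \
  -H "Authorization: Bearer $DSTACK_TOKEN"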
3 changes: 1 addition & 2 deletions src/dstack/_internal/core/models/configurations.py
@@ -57,8 +57,7 @@
DEFAULT_PROBE_UNTIL_READY = False
MAX_PROBE_URL_LEN = 2048
DEFAULT_REPLICA_GROUP_NAME = "0"
DEFAULT_MODEL_PROBE_TIMEOUT = 30
DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions"
OPENAI_MODEL_PROBE_TIMEOUT = 30


class RunConfigurationType(str, Enum):
29 changes: 29 additions & 0 deletions src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2
@@ -24,6 +24,17 @@ server {

{% for location in locations %}
location {{ location.prefix }} {
{% if cors_enabled %}
# Handle CORS preflight before auth (rewrite phase runs before access phase)
if ($request_method = 'OPTIONS') {
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always;
add_header 'Access-Control-Allow-Headers' '*' always;
add_header 'Access-Control-Max-Age' '600' always;
return 204;
}
{% endif %}

{% if auth %}
auth_request /_dstack_auth;
{% endif %}
@@ -46,6 +57,15 @@ server {
location @websocket {
set $dstack_replica_hit 1;
{% if replicas %}
{% if cors_enabled %}
proxy_hide_header 'Access-Control-Allow-Origin';
proxy_hide_header 'Access-Control-Allow-Methods';
proxy_hide_header 'Access-Control-Allow-Headers';
proxy_hide_header 'Access-Control-Allow-Credentials';
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always;
add_header 'Access-Control-Allow-Headers' '*' always;
{% endif %}
proxy_pass http://{{ domain }}.upstream;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
@@ -60,6 +80,15 @@ server {
location @ {
set $dstack_replica_hit 1;
{% if replicas %}
{% if cors_enabled %}
proxy_hide_header 'Access-Control-Allow-Origin';
proxy_hide_header 'Access-Control-Allow-Methods';
proxy_hide_header 'Access-Control-Allow-Headers';
proxy_hide_header 'Access-Control-Allow-Credentials';
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always;
add_header 'Access-Control-Allow-Headers' '*' always;
{% endif %}
proxy_pass http://{{ domain }}.upstream;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
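The template changes above answer CORS preflights in the rewrite phase, before `auth_request`, and re-advertise permissive CORS headers on proxied responses. A hedged sketch of what a browser-style preflight would observe once CORS is enabled for a service (hostname and origin are placeholders, not part of the PR):

```python
# Sketch only: exercises the OPTIONS short-circuit added in the template above.
import requests

resp = requests.options(
    "https://my-service.gateway.example/v1/chat/completions",  # placeholder host
    headers={
        "Origin": "https://chat.example.app",
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": "authorization,content-type",
    },
    timeout=10,
)
# The `if ($request_method = 'OPTIONS')` block returns 204 before auth runs,
# so the preflight itself needs no Authorization header.
assert resp.status_code == 204
assert resp.headers["Access-Control-Allow-Origin"] == "*"
assert "POST" in resp.headers["Access-Control-Allow-Methods"]
```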
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/services/nginx.py
@@ -72,6 +72,7 @@ class ServiceConfig(SiteConfig):
replicas: list[ReplicaConfig]
router: Optional[AnyRouterConfig] = None
router_port: Optional[int] = None
cors_enabled: bool = False


class ModelEntrypointConfig(SiteConfig):
34 changes: 32 additions & 2 deletions src/dstack/_internal/proxy/gateway/services/registry.py
@@ -47,6 +47,7 @@ async def register_service(
service_conn_pool: ServiceConnectionPool,
router: Optional[AnyRouterConfig] = None,
) -> None:
cors_enabled = model is not None and model.type == "chat" and model.format == "openai"
service = models.Service(
project_name=project_name,
run_name=run_name,
@@ -57,6 +58,7 @@
client_max_body_size=client_max_body_size,
replicas=(),
router=router,
cors_enabled=cors_enabled,
)

async with lock:
@@ -261,7 +263,9 @@ async def apply_service(
ReplicaConfig(id=replica.id, socket=conn.app_socket_path)
for replica, conn in replica_conns.items()
]
service_config = await get_nginx_service_config(service, replica_configs)
service_config = await get_nginx_service_config(
service, replica_configs, cors_enabled=service.cors_enabled
)
await nginx.register(service_config, (await repo.get_config()).acme_settings)
return replica_failures

@@ -305,7 +309,7 @@ async def stop_replica_connections(


async def get_nginx_service_config(
service: models.Service, replicas: Iterable[ReplicaConfig]
service: models.Service, replicas: Iterable[ReplicaConfig], cors_enabled: bool = False
) -> ServiceConfig:
limit_req_zones: list[LimitReqZoneConfig] = []
locations: list[LocationConfig] = []
@@ -374,6 +378,7 @@
locations=locations,
replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs
router=service.router,
cors_enabled=cors_enabled,
)


@@ -389,9 +394,34 @@ async def apply_entrypoint(
await nginx.register(config, acme)


async def _migrate_cors_enabled(repo: GatewayProxyRepo) -> None:
"""Migrate services registered before the cors_enabled field was added.

Old gateway versions didn't persist cors_enabled on services. This derives it
from the associated model's format so that CORS is enabled for openai-format
models on gateway restart without requiring service re-registration.
"""
services = await repo.list_services()
openai_run_names: set[tuple[str, str]] = set()
for service in services:
for model in await repo.list_models(service.project_name):
if model.run_name == service.run_name and isinstance(
model.format_spec, models.OpenAIChatModelFormat
):
openai_run_names.add((service.project_name, service.run_name))
for service in services:
if (
not service.cors_enabled
and (service.project_name, service.run_name) in openai_run_names
):
updated = models.Service(**{**service.dict(), "cors_enabled": True})
await repo.set_service(updated)


async def apply_all(
repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool
) -> None:
await _migrate_cors_enabled(repo)
service_tasks = [
apply_service(
service=service,
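The condition driving all of the above is compact enough to restate on its own: CORS is turned on only for chat models in OpenAI format, both at registration time and in the restart migration. A small illustrative restatement (the helper is not part of the PR):

```python
# Mirrors the cors_enabled rule in register_service() and _migrate_cors_enabled();
# the helper name and standalone form are illustrative only.
from typing import Optional


def derive_cors_enabled(model_type: Optional[str], model_format: Optional[str]) -> bool:
    return model_type == "chat" and model_format == "openai"


assert derive_cors_enabled("chat", "openai") is True
assert derive_cors_enabled("chat", "tgi") is False    # non-OpenAI formats stay CORS-off
assert derive_cors_enabled(None, None) is False       # services without a model
```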
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/lib/models.py
@@ -59,6 +59,7 @@ class Service(ImmutableModel):
strip_prefix: bool = True # only used in-server
replicas: tuple[Replica, ...]
router: Optional[AnyRouterConfig] = None
cors_enabled: bool = False # only used on gateways; enabled for openai-format models

@property
def domain_safe(self) -> str:
11 changes: 5 additions & 6 deletions src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -12,8 +12,6 @@
from dstack._internal.core.errors import DockerRegistryError, ServerClientError
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
DEFAULT_MODEL_PROBE_TIMEOUT,
DEFAULT_MODEL_PROBE_URL,
DEFAULT_PROBE_INTERVAL,
DEFAULT_PROBE_METHOD,
DEFAULT_PROBE_READY_AFTER,
@@ -22,6 +20,7 @@
DEFAULT_PROBE_URL,
DEFAULT_REPLICA_GROUP_NAME,
LEGACY_REPO_DIR,
OPENAI_MODEL_PROBE_TIMEOUT,
HTTPHeaderSpec,
PortMapping,
ProbeConfig,
@@ -406,7 +405,7 @@ def _probes(self) -> list[ProbeSpec]:
# Generate default probe if model is set
model = self.run_spec.configuration.model
if isinstance(model, OpenAIChatModel):
return [_default_model_probe_spec(model.name)]
return [_openai_model_probe_spec(model.name, model.prefix)]
return []


@@ -460,7 +459,7 @@ def _probe_config_to_spec(c: ProbeConfig) -> ProbeSpec:
)


def _default_model_probe_spec(model_name: str) -> ProbeSpec:
def _openai_model_probe_spec(model_name: str, prefix: str) -> ProbeSpec:
body = orjson.dumps(
{
"model": model_name,
@@ -471,12 +470,12 @@ def _default_model_probe_spec(model_name: str) -> ProbeSpec:
return ProbeSpec(
type="http",
method="post",
url=DEFAULT_MODEL_PROBE_URL,
url=prefix.rstrip("/") + "/chat/completions",
headers=[
HTTPHeaderSpec(name="Content-Type", value="application/json"),
],
body=body,
timeout=DEFAULT_MODEL_PROBE_TIMEOUT,
timeout=OPENAI_MODEL_PROBE_TIMEOUT,
interval=DEFAULT_PROBE_INTERVAL,
ready_after=DEFAULT_PROBE_READY_AFTER,
)
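The probe URL is now built from the model's `prefix` rather than a fixed constant. A minimal sketch of that construction, assuming the default OpenAI-compatible prefix `/v1` (the value the updated test at the end of this diff asserts):

```python
# Mirrors the url= expression in _openai_model_probe_spec(); the helper is illustrative.
def openai_probe_url(prefix: str) -> str:
    return prefix.rstrip("/") + "/chat/completions"


assert openai_probe_url("/v1") == "/v1/chat/completions"            # default prefix
assert openai_probe_url("/custom/") == "/custom/chat/completions"   # trailing slash handled
```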
19 changes: 15 additions & 4 deletions src/dstack/_internal/server/services/services/__init__.py
@@ -27,6 +27,7 @@
from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec
from dstack._internal.core.models.services import OpenAIChatModel
from dstack._internal.server import settings
from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services import events
@@ -106,10 +107,15 @@ async def _register_service_in_gateway(
wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
if wildcard_domain is None:
raise ServerClientError("Domain is required for gateway")
service_url = f"{service_protocol}://{run_model.run_name}.{wildcard_domain}"
if isinstance(run_spec.configuration.model, OpenAIChatModel):
model_url = service_url + run_spec.configuration.model.prefix
else:
model_url = f"{gateway_protocol}://gateway.{wildcard_domain}"
service_spec = get_service_spec(
configuration=run_spec.configuration,
service_url=f"{service_protocol}://{run_model.run_name}.{wildcard_domain}",
model_url=f"{gateway_protocol}://gateway.{wildcard_domain}",
service_url=service_url,
model_url=model_url,
)

domain = service_spec.get_domain()
@@ -173,10 +179,15 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi
"Rate limits are not supported when running services without a gateway."
" Please configure a gateway or remove `rate_limits` from the service configuration"
)
service_url = f"/proxy/services/{run_model.project.name}/{run_model.run_name}/"
if isinstance(run_spec.configuration.model, OpenAIChatModel):
model_url = service_url.rstrip("/") + run_spec.configuration.model.prefix
else:
model_url = f"/proxy/models/{run_model.project.name}/"
return get_service_spec(
configuration=run_spec.configuration,
service_url=f"/proxy/services/{run_model.project.name}/{run_model.run_name}/",
model_url=f"/proxy/models/{run_model.project.name}/",
service_url=service_url,
model_url=model_url,
)


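For an OpenAI-format model with the default `/v1` prefix, the two branches above now yield model URLs like the following; non-OpenAI formats (e.g. the TGI case in the tests below) keep the previous gateway and `/proxy/models/` URLs. Names mirror the test parameters and are otherwise placeholders:

```python
# Hedged sketch of the resulting URLs; project, run, and gateway names are illustrative.
prefix = "/v1"  # default OpenAIChatModel prefix

# In-server proxy: the model URL is derived from the service URL.
service_url = "/proxy/services/test-project/test-service/"
assert service_url.rstrip("/") + prefix == "/proxy/services/test-project/test-service/v1"

# Gateway: the model URL is the per-run service domain plus the prefix.
gateway_service_url = "https://test-service.default-gateway.example"
assert gateway_service_url + prefix == "https://test-service.default-gateway.example/v1"
```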
36 changes: 30 additions & 6 deletions src/tests/_internal/server/routers/test_runs.py
@@ -588,14 +588,15 @@ def get_service_run_spec(
repo_id: str,
run_name: Optional[str] = None,
gateway: Optional[Union[bool, str]] = None,
model: Union[str, dict] = "test-model",
) -> dict:
return {
"configuration": {
"type": "service",
"commands": ["python -m http.server"],
"port": 8000,
"gateway": gateway,
"model": "test-model",
"model": model,
"repos": [
{
"url": "https://github.com/dstackai/dstack",
@@ -2303,48 +2304,69 @@ def mock_gateway_connections(self) -> Generator[None, None, None]:
"expected_service_url",
"expected_model_url",
"is_gateway",
"model",
),
[
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
None,
"https://test-service.default-gateway.example",
"https://gateway.default-gateway.example",
"https://test-service.default-gateway.example/v1",
True,
"test-model",
id="submits-to-default-gateway",
),
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
True,
"https://test-service.default-gateway.example",
"https://gateway.default-gateway.example",
"https://test-service.default-gateway.example/v1",
True,
"test-model",
id="submits-to-default-gateway-when-gateway-true",
),
pytest.param(
[("default-gateway", True), ("non-default-gateway", False)],
"non-default-gateway",
"https://test-service.non-default-gateway.example",
"https://gateway.non-default-gateway.example",
"https://test-service.non-default-gateway.example/v1",
True,
"test-model",
id="submits-to-specified-gateway",
),
pytest.param(
[("non-default-gateway", False)],
None,
"/proxy/services/test-project/test-service/",
"/proxy/models/test-project/",
"/proxy/services/test-project/test-service/v1",
False,
"test-model",
id="submits-in-server-when-no-default-gateway",
),
pytest.param(
[("default-gateway", True)],
False,
"/proxy/services/test-project/test-service/",
"/proxy/models/test-project/",
"/proxy/services/test-project/test-service/v1",
False,
"test-model",
id="submits-in-server-when-specified",
),
pytest.param(
[("default-gateway", True)],
None,
"https://test-service.default-gateway.example",
"https://gateway.default-gateway.example",
True,
{
"type": "chat",
"name": "test-model",
"format": "tgi",
"chat_template": "test",
"eos_token": "</s>",
},
id="submits-tgi-model-to-gateway",
),
],
)
async def test_submit_to_correct_proxy(
@@ -2357,6 +2379,7 @@ async def test_submit_to_correct_proxy(
expected_service_url: str,
expected_model_url: str,
is_gateway: bool,
model: Union[str, dict],
) -> None:
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user, name="test-project")
@@ -2386,6 +2409,7 @@
repo_id=repo.name,
run_name="test-service",
gateway=specified_gateway_in_run_conf,
model=model,
)
response = await client.post(
f"/api/project/{project.name}/runs/submit",
@@ -1,8 +1,7 @@
import pytest

from dstack._internal.core.models.configurations import (
DEFAULT_MODEL_PROBE_TIMEOUT,
DEFAULT_MODEL_PROBE_URL,
OPENAI_MODEL_PROBE_TIMEOUT,
ProbeConfig,
ServiceConfiguration,
)
@@ -35,8 +34,8 @@ async def test_default_probe_when_model_set(self):
probe = probes[0]
assert probe.type == "http"
assert probe.method == "post"
assert probe.url == DEFAULT_MODEL_PROBE_URL
assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT
assert probe.url == "/v1/chat/completions"
assert probe.timeout == OPENAI_MODEL_PROBE_TIMEOUT
assert len(probe.headers) == 1
assert probe.headers[0].name == "Content-Type"
assert probe.headers[0].value == "application/json"