diff --git a/skills/dstack/SKILL.md b/skills/dstack/SKILL.md index 8bad2b5c79..1d362c8520 100644 --- a/skills/dstack/SKILL.md +++ b/skills/dstack/SKILL.md @@ -222,7 +222,7 @@ resources: - Without gateway: `/proxy/services/f//` - With gateway: `https://./` - Authentication: Unless `auth` is `false`, include `Authorization: Bearer ` on all service requests. -- OpenAI-compatible models: Use `service.url` from `dstack run get --json` and append `/v1` as the base URL; do **not** use deprecated `service.model.base_url` for requests. +- Model endpoint: If `model` is set, `service.model.base_url` from `dstack run get --json` provides the model endpoint. For OpenAI-compatible models (the default, unless format is set otherwise), this will be `service.url` + `/v1`. - Example (with gateway): ```bash curl -sS -X POST "https://./v1/chat/completions" \ diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 9c8b40b6ec..a6e5bffb40 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -57,8 +57,7 @@ DEFAULT_PROBE_UNTIL_READY = False MAX_PROBE_URL_LEN = 2048 DEFAULT_REPLICA_GROUP_NAME = "0" -DEFAULT_MODEL_PROBE_TIMEOUT = 30 -DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions" +OPENAI_MODEL_PROBE_TIMEOUT = 30 class RunConfigurationType(str, Enum): diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index 31f987706a..521e6a23fb 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -24,6 +24,17 @@ server { {% for location in locations %} location {{ location.prefix }} { + {% if cors_enabled %} + # Handle CORS preflight before auth (rewrite phase runs before access phase) + if ($request_method = 'OPTIONS') { + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always; + add_header 'Access-Control-Allow-Headers' '*' always; + add_header 'Access-Control-Max-Age' '600' always; + return 204; + } + {% endif %} + {% if auth %} auth_request /_dstack_auth; {% endif %} @@ -46,6 +57,15 @@ server { location @websocket { set $dstack_replica_hit 1; {% if replicas %} + {% if cors_enabled %} + proxy_hide_header 'Access-Control-Allow-Origin'; + proxy_hide_header 'Access-Control-Allow-Methods'; + proxy_hide_header 'Access-Control-Allow-Headers'; + proxy_hide_header 'Access-Control-Allow-Credentials'; + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always; + add_header 'Access-Control-Allow-Headers' '*' always; + {% endif %} proxy_pass http://{{ domain }}.upstream; proxy_set_header X-Real-IP $remote_addr; proxy_set_header Host $host; @@ -60,6 +80,15 @@ server { location @ { set $dstack_replica_hit 1; {% if replicas %} + {% if cors_enabled %} + proxy_hide_header 'Access-Control-Allow-Origin'; + proxy_hide_header 'Access-Control-Allow-Methods'; + proxy_hide_header 'Access-Control-Allow-Headers'; + proxy_hide_header 'Access-Control-Allow-Credentials'; + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD' always; + add_header 'Access-Control-Allow-Headers' '*' always; + {% endif %} proxy_pass http://{{ domain }}.upstream; proxy_set_header X-Real-IP $remote_addr; proxy_set_header Host $host; diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index bbda92d91b..c971d4197a 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -72,6 +72,7 @@ class ServiceConfig(SiteConfig): replicas: list[ReplicaConfig] router: Optional[AnyRouterConfig] = None router_port: Optional[int] = None + cors_enabled: bool = False class ModelEntrypointConfig(SiteConfig): diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 636d8c38ec..fbc45e7de1 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -47,6 +47,7 @@ async def register_service( service_conn_pool: ServiceConnectionPool, router: Optional[AnyRouterConfig] = None, ) -> None: + cors_enabled = model is not None and model.type == "chat" and model.format == "openai" service = models.Service( project_name=project_name, run_name=run_name, @@ -57,6 +58,7 @@ async def register_service( client_max_body_size=client_max_body_size, replicas=(), router=router, + cors_enabled=cors_enabled, ) async with lock: @@ -261,7 +263,9 @@ async def apply_service( ReplicaConfig(id=replica.id, socket=conn.app_socket_path) for replica, conn in replica_conns.items() ] - service_config = await get_nginx_service_config(service, replica_configs) + service_config = await get_nginx_service_config( + service, replica_configs, cors_enabled=service.cors_enabled + ) await nginx.register(service_config, (await repo.get_config()).acme_settings) return replica_failures @@ -305,7 +309,7 @@ async def stop_replica_connections( async def get_nginx_service_config( - service: models.Service, replicas: Iterable[ReplicaConfig] + service: models.Service, replicas: Iterable[ReplicaConfig], cors_enabled: bool = False ) -> ServiceConfig: limit_req_zones: list[LimitReqZoneConfig] = [] locations: list[LocationConfig] = [] @@ -374,6 +378,7 @@ async def get_nginx_service_config( locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs router=service.router, + cors_enabled=cors_enabled, ) @@ -389,9 +394,34 @@ async def apply_entrypoint( await nginx.register(config, acme) +async def _migrate_cors_enabled(repo: GatewayProxyRepo) -> None: + """Migrate services registered before the cors_enabled field was added. + + Old gateway versions didn't persist cors_enabled on services. This derives it + from the associated model's format so that CORS is enabled for openai-format + models on gateway restart without requiring service re-registration. + """ + services = await repo.list_services() + openai_run_names: set[tuple[str, str]] = set() + for service in services: + for model in await repo.list_models(service.project_name): + if model.run_name == service.run_name and isinstance( + model.format_spec, models.OpenAIChatModelFormat + ): + openai_run_names.add((service.project_name, service.run_name)) + for service in services: + if ( + not service.cors_enabled + and (service.project_name, service.run_name) in openai_run_names + ): + updated = models.Service(**{**service.dict(), "cors_enabled": True}) + await repo.set_service(updated) + + async def apply_all( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool ) -> None: + await _migrate_cors_enabled(repo) service_tasks = [ apply_service( service=service, diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index bf37e0b5aa..f304bbc394 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -59,6 +59,7 @@ class Service(ImmutableModel): strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] router: Optional[AnyRouterConfig] = None + cors_enabled: bool = False # only used on gateways; enabled for openai-format models @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 3b6038ccd9..a9496ad348 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -12,8 +12,6 @@ from dstack._internal.core.errors import DockerRegistryError, ServerClientError from dstack._internal.core.models.common import RegistryAuth from dstack._internal.core.models.configurations import ( - DEFAULT_MODEL_PROBE_TIMEOUT, - DEFAULT_MODEL_PROBE_URL, DEFAULT_PROBE_INTERVAL, DEFAULT_PROBE_METHOD, DEFAULT_PROBE_READY_AFTER, @@ -22,6 +20,7 @@ DEFAULT_PROBE_URL, DEFAULT_REPLICA_GROUP_NAME, LEGACY_REPO_DIR, + OPENAI_MODEL_PROBE_TIMEOUT, HTTPHeaderSpec, PortMapping, ProbeConfig, @@ -406,7 +405,7 @@ def _probes(self) -> list[ProbeSpec]: # Generate default probe if model is set model = self.run_spec.configuration.model if isinstance(model, OpenAIChatModel): - return [_default_model_probe_spec(model.name)] + return [_openai_model_probe_spec(model.name, model.prefix)] return [] @@ -460,7 +459,7 @@ def _probe_config_to_spec(c: ProbeConfig) -> ProbeSpec: ) -def _default_model_probe_spec(model_name: str) -> ProbeSpec: +def _openai_model_probe_spec(model_name: str, prefix: str) -> ProbeSpec: body = orjson.dumps( { "model": model_name, @@ -471,12 +470,12 @@ def _default_model_probe_spec(model_name: str) -> ProbeSpec: return ProbeSpec( type="http", method="post", - url=DEFAULT_MODEL_PROBE_URL, + url=prefix.rstrip("/") + "/chat/completions", headers=[ HTTPHeaderSpec(name="Content-Type", value="application/json"), ], body=body, - timeout=DEFAULT_MODEL_PROBE_TIMEOUT, + timeout=OPENAI_MODEL_PROBE_TIMEOUT, interval=DEFAULT_PROBE_INTERVAL, ready_after=DEFAULT_PROBE_READY_AFTER, ) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 06aa5b0ef0..511cf7cc93 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -27,6 +27,7 @@ from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec +from dstack._internal.core.models.services import OpenAIChatModel from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services import events @@ -106,10 +107,15 @@ async def _register_service_in_gateway( wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None if wildcard_domain is None: raise ServerClientError("Domain is required for gateway") + service_url = f"{service_protocol}://{run_model.run_name}.{wildcard_domain}" + if isinstance(run_spec.configuration.model, OpenAIChatModel): + model_url = service_url + run_spec.configuration.model.prefix + else: + model_url = f"{gateway_protocol}://gateway.{wildcard_domain}" service_spec = get_service_spec( configuration=run_spec.configuration, - service_url=f"{service_protocol}://{run_model.run_name}.{wildcard_domain}", - model_url=f"{gateway_protocol}://gateway.{wildcard_domain}", + service_url=service_url, + model_url=model_url, ) domain = service_spec.get_domain() @@ -173,10 +179,15 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi "Rate limits are not supported when running services without a gateway." " Please configure a gateway or remove `rate_limits` from the service configuration" ) + service_url = f"/proxy/services/{run_model.project.name}/{run_model.run_name}/" + if isinstance(run_spec.configuration.model, OpenAIChatModel): + model_url = service_url.rstrip("/") + run_spec.configuration.model.prefix + else: + model_url = f"/proxy/models/{run_model.project.name}/" return get_service_spec( configuration=run_spec.configuration, - service_url=f"/proxy/services/{run_model.project.name}/{run_model.run_name}/", - model_url=f"/proxy/models/{run_model.project.name}/", + service_url=service_url, + model_url=model_url, ) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index ad8ad878d1..be78414a9e 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -588,6 +588,7 @@ def get_service_run_spec( repo_id: str, run_name: Optional[str] = None, gateway: Optional[Union[bool, str]] = None, + model: Union[str, dict] = "test-model", ) -> dict: return { "configuration": { @@ -595,7 +596,7 @@ def get_service_run_spec( "commands": ["python -m http.server"], "port": 8000, "gateway": gateway, - "model": "test-model", + "model": model, "repos": [ { "url": "https://github.com/dstackai/dstack", @@ -2303,48 +2304,69 @@ def mock_gateway_connections(self) -> Generator[None, None, None]: "expected_service_url", "expected_model_url", "is_gateway", + "model", ), [ pytest.param( [("default-gateway", True), ("non-default-gateway", False)], None, "https://test-service.default-gateway.example", - "https://gateway.default-gateway.example", + "https://test-service.default-gateway.example/v1", True, + "test-model", id="submits-to-default-gateway", ), pytest.param( [("default-gateway", True), ("non-default-gateway", False)], True, "https://test-service.default-gateway.example", - "https://gateway.default-gateway.example", + "https://test-service.default-gateway.example/v1", True, + "test-model", id="submits-to-default-gateway-when-gateway-true", ), pytest.param( [("default-gateway", True), ("non-default-gateway", False)], "non-default-gateway", "https://test-service.non-default-gateway.example", - "https://gateway.non-default-gateway.example", + "https://test-service.non-default-gateway.example/v1", True, + "test-model", id="submits-to-specified-gateway", ), pytest.param( [("non-default-gateway", False)], None, "/proxy/services/test-project/test-service/", - "/proxy/models/test-project/", + "/proxy/services/test-project/test-service/v1", False, + "test-model", id="submits-in-server-when-no-default-gateway", ), pytest.param( [("default-gateway", True)], False, "/proxy/services/test-project/test-service/", - "/proxy/models/test-project/", + "/proxy/services/test-project/test-service/v1", False, + "test-model", id="submits-in-server-when-specified", ), + pytest.param( + [("default-gateway", True)], + None, + "https://test-service.default-gateway.example", + "https://gateway.default-gateway.example", + True, + { + "type": "chat", + "name": "test-model", + "format": "tgi", + "chat_template": "test", + "eos_token": "", + }, + id="submits-tgi-model-to-gateway", + ), ], ) async def test_submit_to_correct_proxy( @@ -2357,6 +2379,7 @@ async def test_submit_to_correct_proxy( expected_service_url: str, expected_model_url: str, is_gateway: bool, + model: Union[str, dict], ) -> None: user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user, name="test-project") @@ -2386,6 +2409,7 @@ async def test_submit_to_correct_proxy( repo_id=repo.name, run_name="test-service", gateway=specified_gateway_in_run_conf, + model=model, ) response = await client.post( f"/api/project/{project.name}/runs/submit", diff --git a/src/tests/_internal/server/services/jobs/configurators/test_service.py b/src/tests/_internal/server/services/jobs/configurators/test_service.py index b52ee297a5..cafab73d9c 100644 --- a/src/tests/_internal/server/services/jobs/configurators/test_service.py +++ b/src/tests/_internal/server/services/jobs/configurators/test_service.py @@ -1,8 +1,7 @@ import pytest from dstack._internal.core.models.configurations import ( - DEFAULT_MODEL_PROBE_TIMEOUT, - DEFAULT_MODEL_PROBE_URL, + OPENAI_MODEL_PROBE_TIMEOUT, ProbeConfig, ServiceConfiguration, ) @@ -35,8 +34,8 @@ async def test_default_probe_when_model_set(self): probe = probes[0] assert probe.type == "http" assert probe.method == "post" - assert probe.url == DEFAULT_MODEL_PROBE_URL - assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT + assert probe.url == "/v1/chat/completions" + assert probe.timeout == OPENAI_MODEL_PROBE_TIMEOUT assert len(probe.headers) == 1 assert probe.headers[0].name == "Content-Type" assert probe.headers[0].value == "application/json"