Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ include = [
"src/dstack/plugins",
"src/dstack/_internal/server",
"src/dstack/_internal/core/services",
"src/dstack/_internal/core/backends/aws",
"src/dstack/_internal/core/backends/kubernetes",
"src/dstack/_internal/core/backends/runpod",
"src/dstack/_internal/cli/services/configurators",
Expand Down
139 changes: 93 additions & 46 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
NoCapacityError,
PlacementGroupInUseError,
PlacementGroupNotSupportedError,
ProvisioningError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
Expand Down Expand Up @@ -291,35 +292,35 @@ def create_instance(
}
if reservation.get("ReservationType") == "capacity-block":
is_capacity_block = True

except botocore.exceptions.ClientError as e:
logger.warning("Got botocore.exceptions.ClientError: %s", e)
raise NoCapacityError()

tried_zones = set()
for subnet_id, az in subnet_id_to_az_map.items():
if az in tried_zones:
continue
tried_zones.add(az)
logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
image_id, username = self._get_image_id_and_username(
ec2_client=ec2_client,
region=instance_offer.region,
gpu_name=(
instance_offer.instance.resources.gpus[0].name
if len(instance_offer.instance.resources.gpus) > 0
else None
),
instance_type=instance_offer.instance.name,
image_config=self.config.os_images,
)
security_group_id = self._create_security_group(
ec2_client=ec2_client,
region=instance_offer.region,
project_id=project_name,
vpc_id=vpc_id,
)
try:
logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
image_id, username = self._get_image_id_and_username(
ec2_client=ec2_client,
region=instance_offer.region,
gpu_name=(
instance_offer.instance.resources.gpus[0].name
if len(instance_offer.instance.resources.gpus) > 0
else None
),
instance_type=instance_offer.instance.name,
image_config=self.config.os_images,
)
security_group_id = self._create_security_group(
ec2_client=ec2_client,
region=instance_offer.region,
project_id=project_name,
vpc_id=vpc_id,
)
response = ec2_resource.create_instances(
response = ec2_resource.create_instances( # pyright: ignore[reportAttributeAccessIssue]
**aws_resources.create_instances_struct(
disk_size=disk_size,
image_id=image_id,
Expand All @@ -343,39 +344,85 @@ def create_instance(
is_capacity_block=is_capacity_block,
)
)
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address
if instance_offer.instance.resources.spot: # it will not terminate the instance
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
)
hostname = _get_instance_ip(instance, allocate_public_ip)
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance.instance_id,
public_ip_enabled=allocate_public_ip,
hostname=hostname,
internal_ip=instance.private_ip_address,
region=instance_offer.region,
availability_zone=az,
reservation=instance.capacity_reservation_id,
price=instance_offer.price,
username=username,
ssh_port=22,
dockerized=True, # because `dstack-shim` is used
ssh_proxy=None,
backend_data=None,
)
except botocore.exceptions.ClientError as e:
logger.warning("Got botocore.exceptions.ClientError: %s", e)
if e.response["Error"]["Code"] == "InvalidParameterValue":
msg = e.response["Error"].get("Message", "")
raise ComputeError(f"Invalid AWS request: {msg}")
continue
instance = response[0]
if instance_offer.instance.resources.spot:
# it will not terminate the instance
try:
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
)
except Exception:
logger.exception(
"Failed to cancel spot instance request. The instance will be terminated."
)
self.terminate_instance(
instance_id=instance.instance_id, region=instance_offer.region
)
raise NoCapacityError()
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance.instance_id,
public_ip_enabled=allocate_public_ip,
hostname=None,
internal_ip=None,
region=instance_offer.region,
availability_zone=az,
reservation=instance.capacity_reservation_id,
price=instance_offer.price,
username=username,
ssh_port=None,
dockerized=True, # because `dstack-shim` is used
ssh_proxy=None,
backend_data=None,
)
raise NoCapacityError()

def update_provisioning_data(
    self,
    provisioning_data: JobProvisioningData,
    project_ssh_public_key: str,
    project_ssh_private_key: str,
):
    """Populate the network fields of ``provisioning_data`` once the instance runs.

    Looks up the EC2 instance named by ``provisioning_data.instance_id`` in
    ``provisioning_data.region`` and, when its state is ``running``, sets
    ``hostname``, ``internal_ip`` and ``ssh_port`` on ``provisioning_data``
    in place. Returns ``None`` without touching ``provisioning_data`` when
    the instance is still ``pending`` or is not yet visible to the API.

    Args:
        provisioning_data: Record to update; mutated in place.
        project_ssh_public_key: Not used by this implementation.
        project_ssh_private_key: Not used by this implementation.

    Raises:
        ProvisioningError: The instance is in any state other than
            ``pending`` or ``running``.
        botocore.exceptions.ClientError: Loading the instance failed with an
            error code other than ``InvalidInstanceID.NotFound``.
    """
    ec2 = self.session.resource("ec2", region_name=provisioning_data.region)
    ec2_instance = ec2.Instance(provisioning_data.instance_id)  # pyright: ignore[reportAttributeAccessIssue]

    try:
        ec2_instance.load()
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] != "InvalidInstanceID.NotFound":
            raise
        # The instance may have been created but not yet be visible due to AWS
        # eventual consistency, so keep waiting instead of failing immediately.
        logger.debug(
            "Instance %s not found. Waiting for the instance to appear"
            " or to timeout if the instance is manually deleted.",
            provisioning_data.instance_id,
        )
        return

    state = ec2_instance.state.get("Name")
    if state == "pending":
        # Not running yet; leave provisioning_data untouched.
        return
    if state != "running":
        # A missing state name is reported the same way as the listed states.
        unusable_states = (None, "shutting-down", "terminated", "stopping", "stopped")
        if state in unusable_states:
            raise ProvisioningError(
                f"Failed to get instance IP address. Instance state is {state}."
            )
        raise ProvisioningError(
            f"Failed to get instance IP address. Unknown instance state {state}."
        )

    # Resolve the address first so a failure here leaves provisioning_data unchanged.
    resolved_hostname = _get_instance_ip(ec2_instance, self.config.allocate_public_ips)
    provisioning_data.hostname = resolved_hostname
    provisioning_data.internal_ip = ec2_instance.private_ip_address
    provisioning_data.ssh_port = 22

def create_placement_group(
self,
placement_group: PlacementGroup,
Expand Down Expand Up @@ -478,7 +525,7 @@ def create_gateway(
allocate_public_ip=configuration.public_ip,
)
try:
response = ec2_resource.create_instances(**instance_struct)
response = ec2_resource.create_instances(**instance_struct) # pyright: ignore[reportAttributeAccessIssue]
except botocore.exceptions.ClientError as e:
msg = f"AWS Error: {e.response['Error']['Code']}"
if e.response["Error"].get("Message"):
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/base/configurator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, List, Optional, TypeVar
from typing import Any, ClassVar, Generic, List, NoReturn, Optional, TypeVar
from uuid import UUID

from dstack._internal.core.backends.base.backend import Backend
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_backend(self, record: StoredBackendRecord) -> Backend:

def raise_invalid_credentials_error(
fields: Optional[List[List[str]]] = None, details: Optional[Any] = None
):
) -> NoReturn:
msg = BackendInvalidCredentialsError.msg
if details:
msg += f": {details}"
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def get_filtered_offers_with_backends(
if not exclude_not_available or offer.availability.is_available():
yield (backend, offer)

logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
logger.debug("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends]
offers_by_backend = []
for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)):
Expand Down