diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py
index 7e1ce1ce..492029d0 100644
--- a/dask_cloudprovider/aws/ecs.py
+++ b/dask_cloudprovider/aws/ecs.py
@@ -1,6 +1,5 @@
 import asyncio
 import logging
-import uuid
 import warnings
 import weakref
 from typing import List, Optional
@@ -224,9 +223,9 @@ async def start(self):
                     "awsvpcConfiguration": {
                         "subnets": self._vpc_subnets,
                         "securityGroups": self._security_groups,
-                        "assignPublicIp": "ENABLED"
-                        if self._use_public_ip
-                        else "DISABLED",
+                        "assignPublicIp": (
+                            "ENABLED" if self._use_public_ip else "DISABLED"
+                        ),
                     }
                 },
             }
@@ -461,7 +460,9 @@ class ECSCluster(SpecCluster, ConfigMixin):
     This creates a dask scheduler and workers on an existing ECS cluster.
 
     All the other required resources such as roles, task definitions, tasks, etc
-    will be created automatically like in :class:`FargateCluster`.
+    will be created automatically like in :class:`FargateCluster`. Resource names
+    include the value of ``self.name`` to uniquely associate them with this cluster,
+    and they are also tagged with ``dask_cluster_name`` using the same value.
 
     Parameters
     ----------
@@ -579,9 +580,11 @@ class ECSCluster(SpecCluster, ConfigMixin):
         Defaults to ``None`` which results in a new cluster being created for you.
     cluster_name_template: str (optional)
         A template to use for the cluster name if ``cluster_arn`` is set to
-        ``None``.
+        ``None``. Valid substitution variables are:
 
-        Defaults to ``'dask-{uuid}'``
+        ``name``: the value of ``self.name`` (usually a UUID)
+
+        Defaults to ``'dask-{name}'``
     execution_role_arn: str (optional)
         The ARN of an existing IAM role to use for ECS execution.
 
@@ -626,9 +629,12 @@ class ECSCluster(SpecCluster, ConfigMixin):
         Default ``None`` (one will be created called ``dask-ecs``)
     cloudwatch_logs_stream_prefix: str (optional)
-        Prefix for log streams.
+        Prefix for log streams. Valid substitution variables are:
+
+        ``name``: the value of ``self.name`` (usually a UUID)
+        ``cluster_name``: the value of ``self.cluster_name`` (the ECS cluster name)
 
-        Defaults to the cluster name.
+        Defaults to ``'{cluster_name}/{name}'``.
     cloudwatch_logs_default_retention: int (optional)
         Retention for logs in days. For use when log group is auto created.
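A minimal sketch of how the new template substitution documented above is expected to expand, assuming a cluster whose ``self.name`` is "1a2b3c4d5e" (the value is made up for illustration):

    # Hypothetical expansion of the new template variables; `name` stands
    # in for self.name, which is usually a UUID.
    name = "1a2b3c4d5e"

    # Default cluster_name_template: "dask-{name}"
    cluster_name = "dask-{name}".format(name=name)
    # -> "dask-1a2b3c4d5e"

    # Default cloudwatch_logs_stream_prefix: "{cluster_name}/{name}"
    stream_prefix = "{cluster_name}/{name}".format(
        cluster_name=cluster_name, name=name
    )
    # -> "dask-1a2b3c4d5e/1a2b3c4d5e"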
@@ -921,7 +927,10 @@ async def _start(
         if self._cloudwatch_logs_stream_prefix is None:
             self._cloudwatch_logs_stream_prefix = self.config.get(
                 "cloudwatch_logs_stream_prefix"
-            ).format(cluster_name=self.cluster_name)
+            ).format(
+                cluster_name=self.cluster_name,
+                name=self.name,
+            )
 
         if self.cloudwatch_logs_group is None:
             self.cloudwatch_logs_group = (
@@ -1025,7 +1034,12 @@ def _new_worker_name(self, worker_number):
 
     @property
     def tags(self):
-        return {**self._tags, **DEFAULT_TAGS, "cluster": self.cluster_name}
+        return {
+            **self._tags,
+            **DEFAULT_TAGS,
+            "cluster": self.cluster_name,
+            "dask_cluster_name": self.name,
+        }
 
     async def _create_cluster(self):
         if not self._fargate_scheduler or not self._fargate_workers:
@@ -1038,7 +1052,10 @@ async def _create_cluster(self):
         self.cluster_name = dask.config.expand_environment_variables(
             self._cluster_name_template
         )
-        self.cluster_name = self.cluster_name.format(uuid=str(uuid.uuid4())[:10])
+        self.cluster_name = self.cluster_name.format(
+            name=self.name,
+            uuid=self.name,  # backwards compatibility for templates that still use {uuid}
+        )
         async with self._client("ecs") as ecs:
             response = await ecs.create_cluster(
                 clusterName=self.cluster_name,
@@ -1059,7 +1076,7 @@ async def _delete_cluster(self):
 
     @property
     def _execution_role_name(self):
-        return "{}-{}".format(self.cluster_name, "execution-role")
+        return "dask-{}-execution-role".format(self.name)
 
     async def _create_execution_role(self):
         async with self._client("iam") as iam:
@@ -1099,7 +1116,7 @@ async def _create_execution_role(self):
 
     @property
     def _task_role_name(self):
-        return "{}-{}".format(self.cluster_name, "task-role")
+        return "dask-{}-task-role".format(self.name)
 
     async def _create_task_role(self):
         async with self._client("iam") as iam:
@@ -1141,6 +1158,8 @@ async def _delete_role(self, role):
             await iam.delete_role(RoleName=role)
 
     async def _create_cloudwatch_logs_group(self):
+        # The log group name omits `name` because the group is shared by all Dask
+        # ECS clusters; log streams include it because they are per-cluster.
log_group_name = "dask-ecs" async with self._client("logs") as logs: groups = await logs.describe_log_groups() @@ -1160,23 +1179,29 @@ async def _create_cloudwatch_logs_group(self): # Note: Not cleaning up the logs here as they may be useful after the cluster is destroyed return log_group_name + @property + def _security_group_name(self): + return "dask-{}-security-group".format(self.name) + async def _create_security_groups(self): async with self._client("ec2") as client: group = await create_default_security_group( - client, self.cluster_name, self._vpc, self.tags + client, self._security_group_name, self._vpc, self.tags ) weakref.finalize(self, self.sync, self._delete_security_groups) return [group] async def _delete_security_groups(self): timeout = Timeout( - 30, "Unable to delete AWS security group " + self.cluster_name, warn=True + 30, + "Unable to delete AWS security group {}".format(self._security_group_name), + warn=True, ) async with self._client("ec2") as ec2: while timeout.run(): try: await ec2.delete_security_group( - GroupName=self.cluster_name, DryRun=False + GroupName=self._security_group_name, DryRun=False ) except Exception: await asyncio.sleep(2) @@ -1185,7 +1210,7 @@ async def _delete_security_groups(self): async def _create_scheduler_task_definition_arn(self): async with self._client("ecs") as ecs: response = await ecs.register_task_definition( - family="{}-{}".format(self.cluster_name, "scheduler"), + family="dask-{}-scheduler".format(self.name), taskRoleArn=self._task_role_arn, executionRoleArn=self._execution_role_arn, networkMode="awsvpc", @@ -1223,14 +1248,18 @@ async def _create_scheduler_task_definition_arn(self): "awslogs-create-group": "true", }, }, - "mountPoints": self._mount_points - if self._mount_points and self._mount_volumes_on_scheduler - else [], + "mountPoints": ( + self._mount_points + if self._mount_points and self._mount_volumes_on_scheduler + else [] + ), } ], - volumes=self._volumes - if self._volumes and self._mount_volumes_on_scheduler - else [], + volumes=( + self._volumes + if self._volumes and self._mount_volumes_on_scheduler + else [] + ), requiresCompatibilities=["FARGATE"] if self._fargate_scheduler else [], runtimePlatform={"cpuArchitecture": self._cpu_architecture}, cpu=str(self._scheduler_cpu), @@ -1255,7 +1284,7 @@ async def _create_worker_task_definition_arn(self): ) async with self._client("ecs") as ecs: response = await ecs.register_task_definition( - family="{}-{}".format(self.cluster_name, "worker"), + family="dask-{}-worker".format(self.name), taskRoleArn=self._task_role_arn, executionRoleArn=self._execution_role_arn, networkMode="awsvpc", diff --git a/dask_cloudprovider/aws/tests/test_ecs.py b/dask_cloudprovider/aws/tests/test_ecs.py index 7f633c05..cbe58af0 100644 --- a/dask_cloudprovider/aws/tests/test_ecs.py +++ b/dask_cloudprovider/aws/tests/test_ecs.py @@ -1,3 +1,6 @@ +from unittest import mock +from unittest.mock import AsyncMock + import pytest aiobotocore = pytest.importorskip("aiobotocore") @@ -6,3 +9,127 @@ def test_import(): from dask_cloudprovider.aws import ECSCluster # noqa from dask_cloudprovider.aws import FargateCluster # noqa + + +def test_reuse_ecs_cluster(): + from dask_cloudprovider.aws import ECSCluster # noqa + + fc1_name = "Spooky" + fc2_name = "Weevil" + vpc_name = "MyNetwork" + vpc_subnets = ["MySubnet1", "MySubnet2"] + cluster_arn = "CompletelyMadeUp" + cluster_name = "Crunchy" + log_group_name = "dask-ecs" + + expected_execution_role_name1 = f"dask-{fc1_name}-execution-role" + 
expected_task_role_name1 = f"dask-{fc1_name}-task-role" + expected_log_stream_prefix1 = f"{cluster_name}/{fc1_name}" + expected_security_group_name1 = f"dask-{fc1_name}-security-group" + expected_scheduler_task_name1 = f"dask-{fc1_name}-scheduler" + expected_worker_task_name1 = f"dask-{fc1_name}-worker" + + expected_execution_role_name2 = f"dask-{fc2_name}-execution-role" + expected_task_role_name2 = f"dask-{fc2_name}-task-role" + expected_log_stream_prefix2 = f"{cluster_name}/{fc2_name}" + expected_security_group_name2 = f"dask-{fc2_name}-security-group" + expected_scheduler_task_name2 = f"dask-{fc2_name}-scheduler" + expected_worker_task_name2 = f"dask-{fc2_name}-worker" + + mock_client = AsyncMock() + mock_client.describe_clusters.return_value = { + "clusters": [{"clusterName": cluster_name}] + } + mock_client.list_account_settings.return_value = {"settings": {"value": "enabled"}} + mock_client.create_role.return_value = {"Role": {"Arn": "Random"}} + mock_client.describe_log_groups.return_value = {"logGroups": []} + + class MockSession: + class MockClient: + async def __aenter__(self, *args, **kwargs): + return mock_client + + async def __aexit__(self, *args, **kwargs): + return + + def create_client(self, *args, **kwargs): + return MockSession.MockClient() + + with ( + mock.patch( + "dask_cloudprovider.aws.ecs.get_session", return_value=MockSession() + ), + mock.patch("distributed.deploy.spec.SpecCluster._start"), + mock.patch("weakref.finalize"), + ): + # Make ourselves a test cluster + fc1 = ECSCluster( + name=fc1_name, + cluster_arn=cluster_arn, + vpc=vpc_name, + subnets=vpc_subnets, + skip_cleanup=True, + ) + # Are we re-using the existing ECS cluster? + assert fc1.cluster_name == cluster_name + # Have we made completely unique AWS resources to run on that cluster? + assert fc1._execution_role_name == expected_execution_role_name1 + assert fc1._task_role_name == expected_task_role_name1 + assert fc1._cloudwatch_logs_stream_prefix == expected_log_stream_prefix1 + assert ( + fc1.scheduler_spec["options"]["log_stream_prefix"] + == expected_log_stream_prefix1 + ) + assert ( + fc1.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix1 + ) + assert fc1.cloudwatch_logs_group == log_group_name + assert fc1.scheduler_spec["options"]["log_group"] == log_group_name + assert fc1.new_spec["options"]["log_group"] == log_group_name + sg_calls = mock_client.create_security_group.call_args_list + assert len(sg_calls) == 1 + assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name1 + td_calls = mock_client.register_task_definition.call_args_list + assert len(td_calls) == 2 + assert td_calls[0].kwargs["family"] == expected_scheduler_task_name1 + assert td_calls[1].kwargs["family"] == expected_worker_task_name1 + + # Reset mocks ready for second cluster + mock_client.create_security_group.reset_mock() + mock_client.register_task_definition.reset_mock() + + # Make ourselves a second test cluster on the same ECS cluster + fc2 = ECSCluster( + name=fc2_name, + cluster_arn=cluster_arn, + vpc=vpc_name, + subnets=vpc_subnets, + skip_cleanup=True, + ) + # Are we re-using the existing ECS cluster? + assert fc2.cluster_name == cluster_name + # Have we made completely unique AWS resources to run on that cluster? 
+ assert fc2._execution_role_name == expected_execution_role_name2 + assert fc2._task_role_name == expected_task_role_name2 + assert fc2._cloudwatch_logs_stream_prefix == expected_log_stream_prefix2 + assert ( + fc2.scheduler_spec["options"]["log_stream_prefix"] + == expected_log_stream_prefix2 + ) + assert ( + fc2.new_spec["options"]["log_stream_prefix"] == expected_log_stream_prefix2 + ) + assert fc2.cloudwatch_logs_group == log_group_name + assert fc2.scheduler_spec["options"]["log_group"] == log_group_name + assert fc2.new_spec["options"]["log_group"] == log_group_name + sg_calls = mock_client.create_security_group.call_args_list + assert len(sg_calls) == 1 + assert sg_calls[0].kwargs["GroupName"] == expected_security_group_name2 + td_calls = mock_client.register_task_definition.call_args_list + assert len(td_calls) == 2 + assert td_calls[0].kwargs["family"] == expected_scheduler_task_name2 + assert td_calls[1].kwargs["family"] == expected_worker_task_name2 + + # Finish up + fc1.close() + fc2.close() diff --git a/dask_cloudprovider/cloudprovider.yaml b/dask_cloudprovider/cloudprovider.yaml index ce5b8df2..2a20106d 100755 --- a/dask_cloudprovider/cloudprovider.yaml +++ b/dask_cloudprovider/cloudprovider.yaml @@ -17,7 +17,7 @@ cloudprovider: image: "daskdev/dask:latest" # Docker image to use for non GPU tasks cpu_architecture: "X86_64" # Runtime platform CPU architecture gpu_image: "rapidsai/rapidsai:latest" # Docker image to use for GPU tasks - cluster_name_template: "dask-{uuid}" # Template to use when creating a cluster + cluster_name_template: "dask-{name}" # Template to use when creating a cluster cluster_arn: "" # ARN of existing ECS cluster to use (if not set one will be created) execution_role_arn: "" # Arn of existing execution role to use (if not set one will be created) task_role_arn: "" # Arn of existing task role to use (if not set one will be created) @@ -25,7 +25,7 @@ cloudprovider: # platform_version: "LATEST" # Fargate platformVersion string like "1.4.0" or "LATEST" cloudwatch_logs_group: "" # Name of existing cloudwatch logs group to use (if not set one will be created) - cloudwatch_logs_stream_prefix: "{cluster_name}" # Stream prefix template + cloudwatch_logs_stream_prefix: "{cluster_name}/{name}" # Stream prefix template cloudwatch_logs_default_retention: 30 # Number of days to retain logs (only applied if not using existing group) vpc: "default" # VPC to use for tasks
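A hypothetical usage sketch of what this change enables: two Dask clusters sharing one pre-existing ECS cluster without their AWS resources colliding. The ARN, VPC, and subnet values below are placeholders, not real resources:

    from dask_cloudprovider.aws import ECSCluster

    shared_arn = "arn:aws:ecs:us-east-1:123456789012:cluster/shared"  # placeholder

    # Each cluster derives uniquely named roles, task definition families,
    # security groups, and log stream prefixes from its `name`, e.g.
    # "dask-analytics-execution-role" vs "dask-etl-execution-role", so both
    # can coexist on the same ECS cluster.
    analytics = ECSCluster(
        name="analytics",
        cluster_arn=shared_arn,
        vpc="vpc-1",            # placeholder
        subnets=["subnet-1"],   # placeholder
    )
    etl = ECSCluster(
        name="etl",
        cluster_arn=shared_arn,
        vpc="vpc-1",            # placeholder
        subnets=["subnet-1"],   # placeholder
    )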