Skip to content

Commit 7b44252

Browse files
authored
Implement custom per-resource tags (#2533)
* Implement custom per-resource tags * Handle client backward compatibility
1 parent 2d7eb7d commit 7b44252

File tree

24 files changed

+322
-33
lines changed

24 files changed

+322
-33
lines changed

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,19 @@ def create_instance(
169169
raise NoCapacityError("No eligible availability zones")
170170

171171
instance_name = generate_unique_instance_name(instance_config)
172-
tags = {
172+
base_tags = {
173173
"Name": instance_name,
174174
"owner": "dstack",
175175
"dstack_project": project_name,
176176
"dstack_name": instance_config.instance_name,
177177
"dstack_user": instance_config.user,
178178
}
179-
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
179+
tags = merge_tags(
180+
base_tags=base_tags,
181+
backend_tags=self.config.tags,
182+
resource_tags=instance_config.tags,
183+
)
184+
tags = aws_resources.filter_invalid_tags(tags)
180185

181186
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
182187
max_efa_interfaces = _get_maximum_efa_interfaces(
@@ -326,15 +331,20 @@ def create_gateway(
326331
ec2_client = self.session.client("ec2", region_name=configuration.region)
327332

328333
instance_name = generate_unique_gateway_instance_name(configuration)
329-
tags = {
334+
base_tags = {
330335
"Name": instance_name,
331336
"owner": "dstack",
332337
"dstack_project": configuration.project_name,
333338
"dstack_name": configuration.instance_name,
334339
}
335340
if settings.DSTACK_VERSION is not None:
336-
tags["dstack_version"] = settings.DSTACK_VERSION
337-
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
341+
base_tags["dstack_version"] = settings.DSTACK_VERSION
342+
tags = merge_tags(
343+
base_tags=base_tags,
344+
backend_tags=self.config.tags,
345+
resource_tags=configuration.tags,
346+
)
347+
tags = aws_resources.filter_invalid_tags(tags)
338348
tags = aws_resources.make_tags(tags)
339349

340350
vpc_id, subnets_ids = get_vpc_id_subnet_id_or_error(
@@ -522,14 +532,19 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
522532
ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
523533

524534
volume_name = generate_unique_volume_name(volume)
525-
tags = {
535+
base_tags = {
526536
"Name": volume_name,
527537
"owner": "dstack",
528538
"dstack_project": volume.project_name,
529539
"dstack_name": volume.name,
530540
"dstack_user": volume.user,
531541
}
532-
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
542+
tags = merge_tags(
543+
base_tags=base_tags,
544+
backend_tags=self.config.tags,
545+
resource_tags=volume.configuration.tags,
546+
)
547+
tags = aws_resources.filter_invalid_tags(tags)
533548

534549
zones = aws_resources.get_availability_zones(
535550
ec2_client=ec2_client, region=volume.configuration.region

src/dstack/_internal/core/backends/aws/resources.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,16 @@ def make_tags(tags: Dict[str, str]) -> List[Dict[str, str]]:
448448
return tags_list
449449

450450

451+
def filter_invalid_tags(tags: Dict[str, str]) -> Dict[str, str]:
452+
filtered_tags = {}
453+
for k, v in tags.items():
454+
if not _is_valid_tag(k, v):
455+
logger.warning("Skipping invalid tag '%s: %s'", k, v)
456+
continue
457+
filtered_tags[k] = v
458+
return filtered_tags
459+
460+
451461
def validate_tags(tags: Dict[str, str]):
452462
for k, v in tags.items():
453463
if not _is_valid_tag(k, v):

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,18 @@ def create_instance(
136136
location=location,
137137
)
138138

139-
tags = {
139+
base_tags = {
140140
"owner": "dstack",
141141
"dstack_project": instance_config.project_name,
142142
"dstack_name": instance_config.instance_name,
143143
"dstack_user": instance_config.user,
144144
}
145-
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
145+
tags = merge_tags(
146+
base_tags=base_tags,
147+
backend_tags=self.config.tags,
148+
resource_tags=instance_config.tags,
149+
)
150+
tags = azure_resources.filter_invalid_tags(tags)
146151

147152
# TODO: Support custom availability_zones.
148153
# Currently, VMs are regional, which means they don't have zone info.
@@ -228,14 +233,19 @@ def create_gateway(
228233
location=configuration.region,
229234
)
230235

231-
tags = {
236+
base_tags = {
232237
"owner": "dstack",
233238
"dstack_project": configuration.project_name,
234239
"dstack_name": configuration.instance_name,
235240
}
236241
if settings.DSTACK_VERSION is not None:
237-
tags["dstack_version"] = settings.DSTACK_VERSION
238-
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
242+
base_tags["dstack_version"] = settings.DSTACK_VERSION
243+
tags = merge_tags(
244+
base_tags=base_tags,
245+
backend_tags=self.config.tags,
246+
resource_tags=configuration.tags,
247+
)
248+
tags = azure_resources.filter_invalid_tags(tags)
239249

240250
vm = _launch_instance(
241251
compute_client=self._compute_client,

src/dstack/_internal/core/backends/azure/resources.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from azure.mgmt.network.models import Subnet
66

77
from dstack._internal.core.errors import BackendError
8+
from dstack._internal.utils.logging import get_logger
9+
10+
logger = get_logger(__name__)
11+
812

913
MAX_RESOURCE_NAME_LEN = 64
1014

@@ -77,6 +81,16 @@ def _is_eligible_private_subnet(
7781
return False
7882

7983

84+
def filter_invalid_tags(tags: Dict[str, str]) -> Dict[str, str]:
85+
filtered_tags = {}
86+
for k, v in tags.items():
87+
if not _is_valid_tag(k, v):
88+
logger.warning("Skipping invalid tag '%s: %s'", k, v)
89+
continue
90+
filtered_tags[k] = v
91+
return filtered_tags
92+
93+
8094
def validate_tags(tags: Dict[str, str]):
8195
for k, v in tags.items():
8296
if not _is_valid_tag(k, v):

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def run_job(
173173
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
174174
volumes=volumes,
175175
reservation=run.run_spec.configuration.reservation,
176+
tags=run.run_spec.merged_profile.tags,
176177
)
177178
instance_offer = instance_offer.copy()
178179
self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
@@ -692,9 +693,18 @@ def get_dstack_gateway_commands() -> List[str]:
692693
]
693694

694695

695-
def merge_tags(tags: Dict[str, str], backend_tags: Optional[Dict[str, str]]) -> Dict[str, str]:
696-
res = tags.copy()
696+
def merge_tags(
697+
base_tags: Dict[str, str],
698+
backend_tags: Optional[Dict[str, str]] = None,
699+
resource_tags: Optional[Dict[str, str]] = None,
700+
) -> Dict[str, str]:
701+
res = base_tags.copy()
702+
# backend_tags have priority over resource_tags
703+
# so that regular users do not override the tags set by admins
697704
if backend_tags is not None:
698705
for k, v in backend_tags.items():
699706
res.setdefault(k, v)
707+
if resource_tags is not None:
708+
for k, v in resource_tags.items():
709+
res.setdefault(k, v)
700710
return res

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,12 @@ def create_instance(
211211
"dstack_name": instance_config.instance_name,
212212
"dstack_user": instance_config.user.lower(),
213213
}
214-
labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)}
215-
labels = merge_tags(tags=labels, backend_tags=self.config.tags)
214+
labels = merge_tags(
215+
base_tags=labels,
216+
backend_tags=self.config.tags,
217+
resource_tags=instance_config.tags,
218+
)
219+
labels = gcp_resources.filter_invalid_labels(labels)
216220
is_tpu = (
217221
_is_tpu(instance_offer.instance.resources.gpus[0].name)
218222
if instance_offer.instance.resources.gpus
@@ -471,8 +475,12 @@ def create_gateway(
471475
"dstack_project": configuration.project_name.lower(),
472476
"dstack_name": configuration.instance_name,
473477
}
474-
labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)}
475-
labels = merge_tags(tags=labels, backend_tags=self.config.tags)
478+
labels = merge_tags(
479+
base_tags=labels,
480+
backend_tags=self.config.tags,
481+
resource_tags=configuration.tags,
482+
)
483+
labels = gcp_resources.filter_invalid_labels(labels)
476484

477485
request = compute_v1.InsertInstanceRequest()
478486
request.zone = zone
@@ -573,8 +581,12 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
573581
"dstack_name": volume.name,
574582
"dstack_user": volume.user,
575583
}
576-
labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)}
577-
labels = merge_tags(tags=labels, backend_tags=self.config.tags)
584+
labels = merge_tags(
585+
base_tags=labels,
586+
backend_tags=self.config.tags,
587+
resource_tags=volume.configuration.tags,
588+
)
589+
labels = gcp_resources.filter_invalid_labels(labels)
578590

579591
disk = compute_v1.Disk()
580592
disk.name = disk_name

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ def get_accelerators(
332332
return [accelerator_config]
333333

334334

335+
def filter_invalid_labels(labels: Dict[str, str]) -> Dict[str, str]:
336+
filtered_labels = {}
337+
for k, v in labels.items():
338+
if not _is_valid_label(k, v):
339+
logger.warning("Skipping invalid label '%s: %s'", k, v)
340+
continue
341+
filtered_labels[k] = v
342+
return filtered_labels
343+
344+
335345
def validate_labels(labels: Dict[str, str]):
336346
for k, v in labels.items():
337347
if not _is_valid_label(k, v):

src/dstack/_internal/core/models/fleets.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from dstack._internal.core.models.resources import Range, ResourcesSpec
2323
from dstack._internal.utils.json_schema import add_extra_schema_types
24+
from dstack._internal.utils.tags import tags_validator
2425

2526

2627
class FleetStatus(str, Enum):
@@ -249,7 +250,18 @@ class FleetProps(CoreModel):
249250

250251

251252
class FleetConfiguration(InstanceGroupParams, FleetProps):
252-
pass
253+
tags: Annotated[
254+
Optional[Dict[str, str]],
255+
Field(
256+
description=(
257+
"The custom tags to associate with the resource."
258+
" The tags are also propagated to the underlying backend resources."
259+
" If there is a conflict with backend-level tags, does not override them"
260+
)
261+
),
262+
] = None
263+
264+
_validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator)
253265

254266

255267
class FleetSpec(CoreModel):

src/dstack/_internal/core/models/gateways.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import datetime
22
from enum import Enum
3-
from typing import Optional, Union
3+
from typing import Dict, Optional, Union
44

5-
from pydantic import Field
5+
from pydantic import Field, validator
66
from typing_extensions import Annotated, Literal
77

88
from dstack._internal.core.models.backends.base import BackendType
99
from dstack._internal.core.models.common import CoreModel
10+
from dstack._internal.utils.tags import tags_validator
1011

1112

1213
class GatewayStatus(str, Enum):
@@ -57,6 +58,18 @@ class GatewayConfiguration(CoreModel):
5758
Optional[AnyGatewayCertificate],
5859
Field(description="The SSL certificate configuration. Defaults to `type: lets-encrypt`"),
5960
] = LetsEncryptGatewayCertificate()
61+
tags: Annotated[
62+
Optional[Dict[str, str]],
63+
Field(
64+
description=(
65+
"The custom tags to associate with the gateway."
66+
" The tags are also propagated to the underlying backend resources."
67+
" If there is a conflict with backend-level tags, does not override them"
68+
)
69+
),
70+
] = None
71+
72+
_validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator)
6073

6174

6275
class GatewaySpec(CoreModel):
@@ -88,7 +101,7 @@ class GatewayPlan(CoreModel):
88101
project_name: str
89102
user: str
90103
spec: GatewaySpec
91-
current_resource: Optional[Gateway]
104+
current_resource: Optional[Gateway] = None
92105

93106

94107
class GatewayComputeConfiguration(CoreModel):
@@ -98,7 +111,8 @@ class GatewayComputeConfiguration(CoreModel):
98111
region: str
99112
public_ip: bool
100113
ssh_key_pub: str
101-
certificate: Optional[AnyGatewayCertificate]
114+
certificate: Optional[AnyGatewayCertificate] = None
115+
tags: Optional[Dict[str, str]] = None
102116

103117

104118
class GatewayProvisioningData(CoreModel):

src/dstack/_internal/core/models/instances.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
from enum import Enum
3-
from typing import List, Optional
3+
from typing import Dict, List, Optional
44
from uuid import UUID
55

66
import gpuhunt
@@ -108,6 +108,7 @@ class InstanceConfiguration(CoreModel):
108108
placement_group_name: Optional[str] = None
109109
reservation: Optional[str] = None
110110
volumes: Optional[List[Volume]] = None
111+
tags: Optional[Dict[str, str]] = None
111112

112113
def get_public_keys(self) -> List[str]:
113114
return [ssh_key.public.strip() for ssh_key in self.ssh_keys]

0 commit comments

Comments
 (0)