Skip to content

Commit a904faf

Browse files
authored
Add missed AWS Sagemaker permission changes (#124)
1 parent 027cf34 commit a904faf

8 files changed

Lines changed: 139 additions & 24 deletions

File tree

cleancloud/doctor/aws.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,13 +645,26 @@ def run_aws_ai_doctor(profile: Optional[str], region: Optional[str] = None) -> N
645645
# present since ListEndpoints already confirmed SageMaker access.
646646
endpoints = sagemaker.list_endpoints(MaxResults=1, StatusEquals="InService")
647647
endpoint_list = endpoints.get("Endpoints", [])
648+
endpoint_config_name = None
648649
if endpoint_list:
649-
sagemaker.describe_endpoint(EndpointName=endpoint_list[0]["EndpointName"])
650+
ep = sagemaker.describe_endpoint(EndpointName=endpoint_list[0]["EndpointName"])
651+
endpoint_config_name = ep.get("EndpointConfigName")
650652
permissions_tested.append("sagemaker:DescribeEndpoint")
651653
success("sagemaker:DescribeEndpoint")
652654
except Exception as e:
653655
permissions_failed.append(("sagemaker:DescribeEndpoint", str(e)))
654656
warn(f"sagemaker:DescribeEndpoint - {e}")
657+
endpoint_config_name = None
658+
659+
try:
660+
# DescribeEndpointConfig — needed to resolve instance type (not in DescribeEndpoint response)
661+
if endpoint_config_name:
662+
sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
663+
permissions_tested.append("sagemaker:DescribeEndpointConfig")
664+
success("sagemaker:DescribeEndpointConfig")
665+
except Exception as e:
666+
permissions_failed.append(("sagemaker:DescribeEndpointConfig", str(e)))
667+
warn(f"sagemaker:DescribeEndpointConfig - {e}")
655668

656669
try:
657670
cloudwatch = session.client("cloudwatch", region_name=region)

cleancloud/providers/aws/rules/sagemaker_endpoint_idle.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def find_idle_sagemaker_endpoints(
199199
raise PermissionError(
200200
"Missing required IAM permissions: "
201201
"sagemaker:ListEndpoints, sagemaker:DescribeEndpoint, "
202-
"cloudwatch:GetMetricStatistics"
202+
"sagemaker:DescribeEndpointConfig, cloudwatch:GetMetricStatistics"
203203
) from e
204204
raise
205205

@@ -246,25 +246,42 @@ def _describe_endpoint(
246246
) -> Tuple[float, bool, int, int, Optional[str]]:
247247
"""Return (monthly_cost, is_gpu, variant_count, total_instances, primary_instance_type).
248248
249+
Instance type lives in the endpoint *config* (describe_endpoint_config), not in
250+
the endpoint summary (describe_endpoint ProductionVariantSummary has no InstanceType
251+
field — only DesiredInstanceCount). We make two calls: one for instance counts,
252+
one for instance types, then pair them by VariantName.
253+
249254
Cost is computed per-variant using DesiredInstanceCount × per-instance cost, summed
250255
across all variants. GPU flag is True if any variant uses an accelerator instance.
251256
252-
Returns (default_cost, False, 1, 1, None) on failure to ensure the endpoint is
253-
still flagged conservatively rather than silently dropped.
257+
Returns (0, False, 0, 0, None) on failure so the endpoint is skipped rather than
258+
flagged with assumed values.
254259
"""
255260
try:
256-
response = sagemaker.describe_endpoint(EndpointName=endpoint_name)
257-
variants = response.get("ProductionVariants", [])
261+
endpoint = sagemaker.describe_endpoint(EndpointName=endpoint_name)
262+
variants = endpoint.get("ProductionVariants", [])
258263
if not variants:
259264
return _DEFAULT_MONTHLY_COST, False, 0, 0, None
260265

266+
# Fetch instance types from the endpoint config
267+
config_name = endpoint.get("EndpointConfigName", "")
268+
instance_type_by_variant: dict = {}
269+
try:
270+
config = sagemaker.describe_endpoint_config(EndpointConfigName=config_name)
271+
for cv in config.get("ProductionVariants", []):
272+
itype = cv.get("InstanceType")
273+
if itype:
274+
instance_type_by_variant[cv["VariantName"]] = itype
275+
except ClientError:
276+
pass # config inaccessible — costs/GPU will use defaults
277+
261278
total_monthly_cost = 0.0
262279
is_gpu = False
263280
total_instances = 0
264281
primary_instance_type: Optional[str] = None
265282

266283
for i, v in enumerate(variants):
267-
itype = v.get("CurrentInstanceType")
284+
itype = instance_type_by_variant.get(v.get("VariantName", ""))
268285
count = v.get("DesiredInstanceCount") or 0
269286
total_instances += count
270287

deploy/cloudformation/cleancloud-role.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,18 @@ Parameters:
2323
Optional. If set, the hub account must provide this value when assuming the role
2424
(confused deputy protection). Pass the same value via --external-id in CleanCloud.
2525
26+
EnableAIScan:
27+
Type: String
28+
Default: "false"
29+
AllowedValues: ["true", "false"]
30+
Description: >
31+
Set to true to attach the AI/ML policy (SageMaker idle endpoint detection).
32+
Required for: cleancloud scan --category ai
33+
See: security/aws/ai-readonly.json
34+
2635
Conditions:
2736
UseExternalId: !Not [!Equals [!Ref ExternalId, ""]]
37+
EnableAIScan: !Equals [!Ref EnableAIScan, "true"]
2838

2939
Resources:
3040
CleanCloudRole:
@@ -97,6 +107,24 @@ Resources:
97107
- Key: Purpose
98108
Value: CrossAccountReadOnlyScanning
99109

110+
CleanCloudAIPolicy:
111+
Type: AWS::IAM::Policy
112+
Condition: EnableAIScan
113+
Properties:
114+
PolicyName: CleanCloudAIReadOnly
115+
Roles:
116+
- !Ref CleanCloudRole
117+
PolicyDocument:
118+
Version: "2012-10-17"
119+
Statement:
120+
- Sid: SageMakerReadOnly
121+
Effect: Allow
122+
Action:
123+
- sagemaker:ListEndpoints
124+
- sagemaker:DescribeEndpoint
125+
- sagemaker:DescribeEndpointConfig
126+
Resource: "*"
127+
100128
Outputs:
101129
RoleArn:
102130
Description: ARN of the CleanCloud IAM role.

deploy/terraform/aws/main.tf

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,29 @@ resource "aws_iam_role" "cleancloud" {
3333
})
3434
}
3535

36+
resource "aws_iam_role_policy" "cleancloud_ai" {
37+
count = var.enable_ai ? 1 : 0
38+
39+
name = "CleanCloudAIReadOnly"
40+
role = aws_iam_role.cleancloud.id
41+
42+
policy = jsonencode({
43+
Version = "2012-10-17"
44+
Statement = [
45+
{
46+
Sid = "SageMakerReadOnly"
47+
Effect = "Allow"
48+
Action = [
49+
"sagemaker:ListEndpoints",
50+
"sagemaker:DescribeEndpoint",
51+
"sagemaker:DescribeEndpointConfig",
52+
]
53+
Resource = "*"
54+
},
55+
]
56+
})
57+
}
58+
3659
resource "aws_iam_role_policy" "cleancloud" {
3760
name = "CleanCloudReadOnly"
3861
role = aws_iam_role.cleancloud.id

deploy/terraform/aws/variables.tf

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ variable "external_id" {
2020
description = "Optional. If set, adds an ExternalId condition to the trust policy (confused deputy protection). Pass the same value via --external-id in CleanCloud."
2121
}
2222

23+
variable "enable_ai" {
24+
type = bool
25+
default = false
26+
description = "Attach the AI/ML policy (SageMaker idle endpoint detection). Required for: cleancloud scan --category ai. See: security/aws/ai-readonly.json"
27+
}
28+
2329
variable "tags" {
2430
type = map(string)
2531
default = {}

docs/rules.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ Confidence thresholds and signal weighting are documented in [confidence.md](con
731731
**Required permissions:**
732732
- `sagemaker:ListEndpoints`
733733
- `sagemaker:DescribeEndpoint`
734+
- `sagemaker:DescribeEndpointConfig`
734735
- `cloudwatch:GetMetricStatistics`
735736

736737
> **Not run by default.** AI/ML rules are opt-in to avoid surprising users who don't use these services. Run with `cleancloud scan --provider aws --category ai` (or `--category all` to combine with hygiene rules). If the permissions above are not granted, the rule is gracefully skipped and reported in the skipped rules section — it will not fail the scan. Attach [`security/aws/ai-readonly.json`](../security/aws/ai-readonly.json) to your IAM role to enable this rule.

security/aws/ai-readonly.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
"Effect": "Allow",
77
"Action": [
88
"sagemaker:ListEndpoints",
9-
"sagemaker:DescribeEndpoint"
9+
"sagemaker:DescribeEndpoint",
10+
"sagemaker:DescribeEndpointConfig"
1011
],
1112
"Resource": "*"
1213
}

tests/cleancloud/providers/aws/test_aws_sagemaker_endpoint_idle.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,27 @@ def _make_endpoint(name="test-endpoint", age_days=30):
3434
def _make_describe_response(
3535
instance_type="ml.m5.xlarge", variant_count=1, desired_instance_count=1
3636
):
37-
"""Build a describe_endpoint response with DesiredInstanceCount on each variant."""
37+
"""Build a describe_endpoint response.
38+
39+
ProductionVariantSummary does NOT include InstanceType — that lives in the
40+
endpoint config. We include EndpointConfigName so the rule can fetch it.
41+
"""
3842
variants = [
3943
{
4044
"VariantName": f"variant-{i}",
41-
"CurrentInstanceType": instance_type,
4245
"CurrentInstanceCount": desired_instance_count,
4346
"DesiredInstanceCount": desired_instance_count,
4447
}
4548
for i in range(variant_count)
4649
]
50+
return {"ProductionVariants": variants, "EndpointConfigName": "test-config"}
51+
52+
53+
def _make_describe_config_response(instance_type="ml.m5.xlarge", variant_count=1):
54+
"""Build a describe_endpoint_config response with InstanceType per variant."""
55+
variants = [
56+
{"VariantName": f"variant-{i}", "InstanceType": instance_type} for i in range(variant_count)
57+
]
4758
return {"ProductionVariants": variants}
4859

4960

@@ -73,6 +84,7 @@ def test_idle_cpu_endpoint_detected():
7384
paginator = sagemaker.get_paginator.return_value
7485
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
7586
sagemaker.describe_endpoint.return_value = _make_describe_response("ml.m5.xlarge")
87+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response("ml.m5.xlarge")
7688
cloudwatch.get_metric_statistics.return_value = _no_invocations()
7789

7890
session = _make_session(sagemaker, cloudwatch)
@@ -99,6 +111,9 @@ def test_idle_gpu_endpoint_detected_high_risk():
99111
paginator = sagemaker.get_paginator.return_value
100112
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
101113
sagemaker.describe_endpoint.return_value = _make_describe_response("ml.p3.2xlarge")
114+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response(
115+
"ml.p3.2xlarge"
116+
)
102117
cloudwatch.get_metric_statistics.return_value = _no_invocations()
103118

104119
session = _make_session(sagemaker, cloudwatch)
@@ -156,6 +171,7 @@ def test_timezone_naive_creation_time_handled():
156171

157172
paginator.paginate.return_value = [{"Endpoints": [endpoint]}]
158173
sagemaker.describe_endpoint.return_value = _make_describe_response("ml.m5.xlarge")
174+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response("ml.m5.xlarge")
159175
cloudwatch.get_metric_statistics.return_value = _no_invocations()
160176

161177
session = _make_session(sagemaker, cloudwatch)
@@ -212,7 +228,8 @@ def test_missing_desired_instance_count_treated_as_zero():
212228
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
213229
# No DesiredInstanceCount key — AWS response omits it
214230
sagemaker.describe_endpoint.return_value = {
215-
"ProductionVariants": [{"VariantName": "v1", "CurrentInstanceType": "ml.m5.xlarge"}]
231+
"ProductionVariants": [{"VariantName": "v1"}],
232+
"EndpointConfigName": "test-config",
216233
}
217234
cloudwatch.get_metric_statistics.return_value = _no_invocations()
218235

@@ -233,9 +250,10 @@ def test_partial_scaled_to_zero_still_flagged():
233250
# 2 variants: one with 1 instance, one with 0
234251
sagemaker.describe_endpoint.return_value = {
235252
"ProductionVariants": [
236-
{"VariantName": "v1", "CurrentInstanceType": "ml.m5.xlarge", "DesiredInstanceCount": 1},
237-
{"VariantName": "v2", "CurrentInstanceType": "ml.m5.xlarge", "DesiredInstanceCount": 0},
238-
]
253+
{"VariantName": "v1", "DesiredInstanceCount": 1},
254+
{"VariantName": "v2", "DesiredInstanceCount": 0},
255+
],
256+
"EndpointConfigName": "test-config",
239257
}
240258
cloudwatch.get_metric_statistics.return_value = _no_invocations()
241259

@@ -386,6 +404,9 @@ def test_g4dn_instance_detected_as_gpu():
386404
paginator = sagemaker.get_paginator.return_value
387405
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
388406
sagemaker.describe_endpoint.return_value = _make_describe_response("ml.g4dn.xlarge")
407+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response(
408+
"ml.g4dn.xlarge"
409+
)
389410
cloudwatch.get_metric_statistics.return_value = _no_invocations()
390411

391412
session = _make_session(sagemaker, cloudwatch)
@@ -404,6 +425,9 @@ def test_inf1_instance_detected_as_gpu():
404425
paginator = sagemaker.get_paginator.return_value
405426
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
406427
sagemaker.describe_endpoint.return_value = _make_describe_response("ml.inf1.xlarge")
428+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response(
429+
"ml.inf1.xlarge"
430+
)
407431
cloudwatch.get_metric_statistics.return_value = _no_invocations()
408432

409433
session = _make_session(sagemaker, cloudwatch)
@@ -427,6 +451,9 @@ def test_multi_variant_cost_scaled():
427451
sagemaker.describe_endpoint.return_value = _make_describe_response(
428452
"ml.m5.xlarge", variant_count=3
429453
)
454+
sagemaker.describe_endpoint_config.return_value = _make_describe_config_response(
455+
"ml.m5.xlarge", variant_count=3
456+
)
430457
cloudwatch.get_metric_statistics.return_value = _no_invocations()
431458

432459
session = _make_session(sagemaker, cloudwatch)
@@ -448,16 +475,15 @@ def test_multi_variant_mixed_instance_types_cost():
448475
paginator.paginate.return_value = [{"Endpoints": [_make_endpoint(age_days=30)]}]
449476
sagemaker.describe_endpoint.return_value = {
450477
"ProductionVariants": [
451-
{
452-
"VariantName": "cpu",
453-
"CurrentInstanceType": "ml.m5.xlarge",
454-
"DesiredInstanceCount": 2,
455-
},
456-
{
457-
"VariantName": "gpu",
458-
"CurrentInstanceType": "ml.g4dn.xlarge",
459-
"DesiredInstanceCount": 1,
460-
},
478+
{"VariantName": "cpu", "DesiredInstanceCount": 2},
479+
{"VariantName": "gpu", "DesiredInstanceCount": 1},
480+
],
481+
"EndpointConfigName": "test-config",
482+
}
483+
sagemaker.describe_endpoint_config.return_value = {
484+
"ProductionVariants": [
485+
{"VariantName": "cpu", "InstanceType": "ml.m5.xlarge"},
486+
{"VariantName": "gpu", "InstanceType": "ml.g4dn.xlarge"},
461487
]
462488
}
463489
cloudwatch.get_metric_statistics.return_value = _no_invocations()

0 commit comments

Comments
 (0)