Skip to content

Commit dbe7a03

Browse files
committed
reusing clients + bug fixes
1 parent 1077d11 commit dbe7a03

1 file changed

Lines changed: 28 additions & 11 deletions

File tree

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Licensed under the Apache License, Version 2.0
33
"""FeatureGroup with Lake Formation support."""
44

5+
import json
56
import logging
67
from typing import List, Optional
78

@@ -132,6 +133,7 @@ def _generate_s3_deny_policy(
132133
s3_prefix: str,
133134
lake_formation_role_arn: str,
134135
feature_store_role_arn: str,
136+
region: Optional[str] = None,
135137
) -> dict:
136138
"""
137139
Generate an S3 deny policy for Lake Formation governance.
@@ -144,6 +146,8 @@ def _generate_s3_deny_policy(
144146
s3_prefix: S3 prefix path (without bucket name).
145147
lake_formation_role_arn: Lake Formation registration role ARN.
146148
feature_store_role_arn: Feature Store execution role ARN.
149+
region: AWS region name (e.g., 'us-west-2'). Used to determine the correct
150+
partition for S3 ARNs. If not provided, the 'aws' partition is used.
147151
148152
Returns:
149153
S3 bucket policy as a dict with valid JSON structure containing:
@@ -152,6 +156,8 @@ def _generate_s3_deny_policy(
152156
1. Deny GetObject, PutObject, DeleteObject on data prefix except allowed principals
153157
2. Deny ListBucket on bucket with prefix condition except allowed principals
154158
"""
159+
partition = aws_partition(region) if region else "aws"
160+
155161
policy = {
156162
"Version": "2012-10-17",
157163
"Statement": [
@@ -160,7 +166,7 @@ def _generate_s3_deny_policy(
160166
"Effect": "Deny",
161167
"Principal": "*",
162168
"Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"],
163-
"Resource": f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*",
169+
"Resource": f"arn:{partition}:s3:::{bucket_name}/{s3_prefix}/*",
164170
"Condition": {
165171
"StringNotEquals": {
166172
"aws:PrincipalArn": [
@@ -175,7 +181,7 @@ def _generate_s3_deny_policy(
175181
"Effect": "Deny",
176182
"Principal": "*",
177183
"Action": "s3:ListBucket",
178-
"Resource": f"arn:aws:s3:::{bucket_name}",
184+
"Resource": f"arn:{partition}:s3:::{bucket_name}",
179185
"Condition": {
180186
"StringLike": {"s3:prefix": f"{s3_prefix}/*"},
181187
"StringNotEquals": {
@@ -196,7 +202,11 @@ def _get_lake_formation_client(
196202
region: Optional[str] = None,
197203
):
198204
"""
199-
Get a Lake Formation client.
205+
Get a Lake Formation client, reusing a cached client when possible.
206+
207+
The client is cached on the instance, keyed by (id(session), region). Subsequent
208+
calls with the same arguments return the existing client instead of creating
209+
a new one.
200210
201211
Args:
202212
session: Boto3 session. If not provided, a new session will be created.
@@ -205,9 +215,17 @@ def _get_lake_formation_client(
205215
Returns:
206216
A boto3 Lake Formation client.
207217
"""
208-
# TODO: don't create a new client for each call
209-
boto_session = session or Session()
210-
return boto_session.client("lakeformation", region_name=region)
218+
cache_key = (id(session), region)
219+
if not hasattr(self, "_lf_client_cache"):
220+
self._lf_client_cache: dict = {}
221+
222+
if cache_key not in self._lf_client_cache:
223+
boto_session = session or Session()
224+
self._lf_client_cache[cache_key] = boto_session.client(
225+
"lakeformation", region_name=region
226+
)
227+
228+
return self._lf_client_cache[cache_key]
211229

212230
def _register_s3_with_lake_formation(
213231
self,
@@ -242,7 +260,7 @@ def _register_s3_with_lake_formation(
242260

243261
# Get region from session if not provided
244262
if region is None and session is not None:
245-
region = session.region_name()
263+
region = session.region_name
246264

247265
client = self._get_lake_formation_client(session, region)
248266
resource_arn = self._s3_uri_to_arn(s3_location, region)
@@ -288,7 +306,7 @@ def _revoke_iam_allowed_principal(
288306
"""
289307
# Get region from session if not provided
290308
if region is None and session is not None:
291-
region = session.region_name()
309+
region = session.region_name
292310

293311
client = self._get_lake_formation_client(session, region)
294312

@@ -341,7 +359,7 @@ def _grant_lake_formation_permissions(
341359
"""
342360
# Get region from session if not provided
343361
if region is None and session is not None:
344-
region = session.region_name()
362+
region = session.region_name
345363

346364
client = self._get_lake_formation_client(session, region)
347365
permissions = ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"]
@@ -569,11 +587,10 @@ def enable_lake_formation(
569587
s3_prefix=s3_prefix,
570588
lake_formation_role_arn=lf_role_arn,
571589
feature_store_role_arn=role_arn_str,
590+
region=region,
572591
)
573592

574593
# Print policy with clear instructions
575-
import json
576-
577594
print("\n" + "=" * 80)
578595
print("S3 Bucket Policy Update recommended")
579596
print("=" * 80)

0 commit comments

Comments
 (0)