Skip to content

Commit d593488

Browse files
committed
feat: add disable_hybrid_access_mode + update tests
1 parent 914f0fd commit d593488

2 files changed

Lines changed: 211 additions & 15 deletions

File tree

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,17 @@ class LakeFormationConfig(Base):
4242
show_s3_policy: If True, prints the S3 deny policy to the console after successful
4343
Lake Formation setup. This policy should be added to your S3 bucket to restrict
4444
access to only the allowed principals. Default is False.
45+
disable_hybrid_access_mode: If True, revokes IAMAllowedPrincipal permissions from
46+
the Glue table, moving it to Lake Formation-only access mode. If False, keeps
47+
hybrid access mode where both IAM and Lake Formation control access.
48+
Default is True (LF-only mode).
4549
"""
4650

4751
enabled: bool = False
4852
use_service_linked_role: bool = True
4953
registration_role_arn: Optional[str] = None
5054
show_s3_policy: bool = False
55+
disable_hybrid_access_mode: bool = True
5156

5257

5358
class FeatureGroup(CoreFeatureGroup):
@@ -395,6 +400,7 @@ def enable_lake_formation(
395400
registration_role_arn: Optional[str] = None,
396401
wait_for_active: bool = False,
397402
show_s3_policy: bool = False,
403+
disable_hybrid_access_mode: bool = True,
398404
) -> dict:
399405
"""
400406
Enable Lake Formation governance for this Feature Group's offline store.
@@ -404,7 +410,8 @@ def enable_lake_formation(
404410
2. Validates Feature Group status is 'Created'
405411
3. Registers the offline store S3 location as data lake location
406412
4. Grants the execution role permissions on the Glue table
407-
5. Revokes IAMAllowedPrincipal permissions from the Glue table
413+
5. Optionally revokes IAMAllowedPrincipal permissions from the Glue table
414+
(controlled by disable_hybrid_access_mode)
408415
409416
The role ARN is automatically extracted from the Feature Group's configuration.
410417
Each phase depends on the success of the previous phase - if any phase fails,
@@ -424,11 +431,15 @@ def enable_lake_formation(
424431
show_s3_policy: If True, prints the S3 deny policy to the console after successful
425432
Lake Formation setup. This policy should be added to your S3 bucket to restrict
426433
access to only the allowed principals. Default is False.
434+
disable_hybrid_access_mode: If True, revokes IAMAllowedPrincipal permissions from
435+
the Glue table, moving it to Lake Formation-only access mode. If False, keeps
436+
hybrid access mode where both IAM and Lake Formation control access.
437+
Default is True.
427438
428439
Returns:
429440
Dict with status of each Lake Formation operation:
430441
- s3_registration: bool
431-
- iam_principal_revoked: bool
442+
- iam_principal_revoked: bool or None (None when disable_hybrid_access_mode=False)
432443
- permissions_granted: bool
433444
434445
Raises:
@@ -548,19 +559,25 @@ def enable_lake_formation(
548559
f"Subsequent phases skipped. Results: {results}"
549560
)
550561

551-
# Phase 3: Revoke IAMAllowedPrincipal permissions
552-
try:
553-
results["iam_principal_revoked"] = self._revoke_iam_allowed_principal(
554-
database_name_str, table_name_str, session, region
555-
)
556-
except Exception as e:
557-
raise RuntimeError(
558-
f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}. Error: {e}"
559-
) from e
560-
561-
if not results["iam_principal_revoked"]:
562-
raise RuntimeError(
563-
f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}"
562+
# Phase 3: Revoke IAMAllowedPrincipal permissions (if disabling hybrid access mode)
563+
if disable_hybrid_access_mode:
564+
try:
565+
results["iam_principal_revoked"] = self._revoke_iam_allowed_principal(
566+
database_name_str, table_name_str, session, region
567+
)
568+
except Exception as e:
569+
raise RuntimeError(
570+
f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}. Error: {e}"
571+
) from e
572+
573+
if not results["iam_principal_revoked"]:
574+
raise RuntimeError(
575+
f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}"
576+
)
577+
else:
578+
results["iam_principal_revoked"] = None
579+
logger.info(
580+
"Skipping IAMAllowedPrincipal revocation - hybrid access mode preserved."
564581
)
565582

566583
logger.info(f"Lake Formation setup complete for {self.feature_group_name}: {results}")
@@ -724,5 +741,6 @@ def create(
724741
use_service_linked_role=lake_formation_config.use_service_linked_role,
725742
registration_role_arn=lake_formation_config.registration_role_arn,
726743
show_s3_policy=lake_formation_config.show_s3_policy,
744+
disable_hybrid_access_mode=lake_formation_config.disable_hybrid_access_mode,
727745
)
728746
return feature_group

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,37 @@ def test_creates_client_with_provided_session(self):
6969
mock_session.client.assert_called_with("lakeformation", region_name="us-west-2")
7070
assert client == mock_client
7171

72+
def test_caches_client_for_same_session_and_region(self):
73+
"""Test that repeated calls with the same session and region reuse the cached client."""
74+
mock_session = MagicMock()
75+
mock_client = MagicMock()
76+
mock_session.client.return_value = mock_client
77+
78+
fg = MagicMock(spec=FeatureGroup)
79+
fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg)
80+
81+
client1 = fg._get_lake_formation_client(session=mock_session, region="us-west-2")
82+
client2 = fg._get_lake_formation_client(session=mock_session, region="us-west-2")
83+
84+
assert client1 is client2
85+
mock_session.client.assert_called_once()
86+
87+
def test_creates_new_client_for_different_region(self):
88+
"""Test that a different region produces a new client."""
89+
mock_session = MagicMock()
90+
mock_client_west = MagicMock()
91+
mock_client_east = MagicMock()
92+
mock_session.client.side_effect = [mock_client_west, mock_client_east]
93+
94+
fg = MagicMock(spec=FeatureGroup)
95+
fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg)
96+
97+
client1 = fg._get_lake_formation_client(session=mock_session, region="us-west-2")
98+
client2 = fg._get_lake_formation_client(session=mock_session, region="us-east-1")
99+
100+
assert client1 is not client2
101+
assert mock_session.client.call_count == 2
102+
72103

73104
class TestRegisterS3WithLakeFormation:
74105
"""Tests for _register_s3_with_lake_formation method."""
@@ -942,6 +973,7 @@ def test_enable_lake_formation_called_when_enabled(
942973
use_service_linked_role=True,
943974
registration_role_arn=None,
944975
show_s3_policy=False,
976+
disable_hybrid_access_mode=True,
945977
)
946978
# Verify the feature group was returned
947979
assert result == mock_fg
@@ -1150,11 +1182,157 @@ def test_use_service_linked_role_extraction_from_config(
11501182
use_service_linked_role=use_slr,
11511183
registration_role_arn=expected_registration_role,
11521184
show_s3_policy=False,
1185+
disable_hybrid_access_mode=True,
11531186
)
11541187
# Verify the feature group was returned
11551188
assert result == mock_fg
11561189

11571190

1191+
class TestDisableHybridAccessMode:
1192+
"""Tests for disable_hybrid_access_mode parameter in enable_lake_formation."""
1193+
1194+
def setup_method(self):
1195+
"""Set up test fixtures."""
1196+
from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig
1197+
1198+
self.fg = FeatureGroup(feature_group_name="test-fg")
1199+
self.fg.offline_store_config = OfflineStoreConfig(
1200+
s3_storage_config=S3StorageConfig(
1201+
s3_uri="s3://test-bucket/path",
1202+
resolved_output_s3_uri="s3://test-bucket/resolved-path",
1203+
),
1204+
data_catalog_config=DataCatalogConfig(
1205+
catalog="AwsDataCatalog", database="test_db", table_name="test_table"
1206+
),
1207+
)
1208+
self.fg.role_arn = "arn:aws:iam::123456789012:role/TestRole"
1209+
self.fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg"
1210+
self.fg.feature_group_status = "Created"
1211+
1212+
@patch.object(FeatureGroup, "refresh")
1213+
@patch.object(FeatureGroup, "_register_s3_with_lake_formation")
1214+
@patch.object(FeatureGroup, "_grant_lake_formation_permissions")
1215+
@patch.object(FeatureGroup, "_revoke_iam_allowed_principal")
1216+
def test_revoke_called_when_disable_hybrid_access_mode_true(
1217+
self, mock_revoke, mock_grant, mock_register, mock_refresh
1218+
):
1219+
"""Test that IAMAllowedPrincipal is revoked when disable_hybrid_access_mode=True (default)."""
1220+
mock_register.return_value = True
1221+
mock_grant.return_value = True
1222+
mock_revoke.return_value = True
1223+
1224+
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=True)
1225+
1226+
mock_revoke.assert_called_once()
1227+
assert result["iam_principal_revoked"] is True
1228+
1229+
@patch.object(FeatureGroup, "refresh")
1230+
@patch.object(FeatureGroup, "_register_s3_with_lake_formation")
1231+
@patch.object(FeatureGroup, "_grant_lake_formation_permissions")
1232+
@patch.object(FeatureGroup, "_revoke_iam_allowed_principal")
1233+
def test_revoke_skipped_when_disable_hybrid_access_mode_false(
1234+
self, mock_revoke, mock_grant, mock_register, mock_refresh
1235+
):
1236+
"""Test that IAMAllowedPrincipal revocation is skipped when disable_hybrid_access_mode=False."""
1237+
mock_register.return_value = True
1238+
mock_grant.return_value = True
1239+
1240+
result = self.fg.enable_lake_formation(disable_hybrid_access_mode=False)
1241+
1242+
mock_revoke.assert_not_called()
1243+
assert result["iam_principal_revoked"] is None
1244+
assert result["s3_registration"] is True
1245+
assert result["permissions_granted"] is True
1246+
1247+
@patch.object(FeatureGroup, "refresh")
1248+
@patch.object(FeatureGroup, "_register_s3_with_lake_formation")
1249+
@patch.object(FeatureGroup, "_grant_lake_formation_permissions")
1250+
@patch.object(FeatureGroup, "_revoke_iam_allowed_principal")
1251+
def test_default_disable_hybrid_access_mode_is_true(
1252+
self, mock_revoke, mock_grant, mock_register, mock_refresh
1253+
):
1254+
"""Test that disable_hybrid_access_mode defaults to True (revoke is called by default)."""
1255+
mock_register.return_value = True
1256+
mock_grant.return_value = True
1257+
mock_revoke.return_value = True
1258+
1259+
# Call without specifying disable_hybrid_access_mode
1260+
result = self.fg.enable_lake_formation()
1261+
1262+
mock_revoke.assert_called_once()
1263+
assert result["iam_principal_revoked"] is True
1264+
1265+
1266+
class TestCreateWithLakeFormationDisableHybridAccessMode:
1267+
"""Tests for disable_hybrid_access_mode passed through create() via LakeFormationConfig."""
1268+
1269+
@patch("sagemaker.core.resources.Base.get_sagemaker_client")
1270+
@patch.object(FeatureGroup, "get")
1271+
@patch.object(FeatureGroup, "wait_for_status")
1272+
@patch.object(FeatureGroup, "enable_lake_formation")
1273+
def test_disable_hybrid_access_mode_false_passed_through_create(
1274+
self, mock_enable_lf, mock_wait, mock_get, mock_get_client
1275+
):
1276+
"""Test that disable_hybrid_access_mode=False is passed through create() to enable_lake_formation."""
1277+
from sagemaker.core.shapes import FeatureDefinition, OfflineStoreConfig, S3StorageConfig
1278+
1279+
mock_client = MagicMock()
1280+
mock_client.create_feature_group.return_value = {
1281+
"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test"
1282+
}
1283+
mock_get_client.return_value = mock_client
1284+
1285+
mock_fg = MagicMock(spec=FeatureGroup)
1286+
mock_fg.wait_for_status = mock_wait
1287+
mock_fg.enable_lake_formation = mock_enable_lf
1288+
mock_get.return_value = mock_fg
1289+
1290+
feature_definitions = [
1291+
FeatureDefinition(feature_name="record_id", feature_type="String"),
1292+
FeatureDefinition(feature_name="event_time", feature_type="String"),
1293+
]
1294+
1295+
lf_config = LakeFormationConfig()
1296+
lf_config.enabled = True
1297+
lf_config.disable_hybrid_access_mode = False
1298+
1299+
FeatureGroup.create(
1300+
feature_group_name="test-fg",
1301+
record_identifier_feature_name="record_id",
1302+
event_time_feature_name="event_time",
1303+
feature_definitions=feature_definitions,
1304+
offline_store_config=OfflineStoreConfig(
1305+
s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path")
1306+
),
1307+
role_arn="arn:aws:iam::123456789012:role/TestRole",
1308+
lake_formation_config=lf_config,
1309+
)
1310+
1311+
mock_enable_lf.assert_called_once_with(
1312+
session=None,
1313+
region=None,
1314+
use_service_linked_role=True,
1315+
registration_role_arn=None,
1316+
show_s3_policy=False,
1317+
disable_hybrid_access_mode=False,
1318+
)
1319+
1320+
1321+
class TestLakeFormationConfigDefaults:
1322+
"""Tests for LakeFormationConfig default values."""
1323+
1324+
def test_disable_hybrid_access_mode_defaults_to_true(self):
1325+
"""Test that LakeFormationConfig.disable_hybrid_access_mode defaults to True."""
1326+
config = LakeFormationConfig()
1327+
assert config.disable_hybrid_access_mode is True
1328+
1329+
def test_disable_hybrid_access_mode_can_be_set_to_false(self):
1330+
"""Test that LakeFormationConfig.disable_hybrid_access_mode can be set to False."""
1331+
config = LakeFormationConfig()
1332+
config.disable_hybrid_access_mode = False
1333+
assert config.disable_hybrid_access_mode is False
1334+
1335+
11581336
class TestExtractAccountIdFromArn:
11591337
"""Tests for _extract_account_id_from_arn static method."""
11601338

0 commit comments

Comments
 (0)