@@ -69,6 +69,37 @@ def test_creates_client_with_provided_session(self):
6969 mock_session .client .assert_called_with ("lakeformation" , region_name = "us-west-2" )
7070 assert client == mock_client
7171
72+ def test_caches_client_for_same_session_and_region (self ):
73+ """Test that repeated calls with the same session and region reuse the cached client."""
74+ mock_session = MagicMock ()
75+ mock_client = MagicMock ()
76+ mock_session .client .return_value = mock_client
77+
78+ fg = MagicMock (spec = FeatureGroup )
79+ fg ._get_lake_formation_client = FeatureGroup ._get_lake_formation_client .__get__ (fg )
80+
81+ client1 = fg ._get_lake_formation_client (session = mock_session , region = "us-west-2" )
82+ client2 = fg ._get_lake_formation_client (session = mock_session , region = "us-west-2" )
83+
84+ assert client1 is client2
85+ mock_session .client .assert_called_once ()
86+
87+ def test_creates_new_client_for_different_region (self ):
88+ """Test that a different region produces a new client."""
89+ mock_session = MagicMock ()
90+ mock_client_west = MagicMock ()
91+ mock_client_east = MagicMock ()
92+ mock_session .client .side_effect = [mock_client_west , mock_client_east ]
93+
94+ fg = MagicMock (spec = FeatureGroup )
95+ fg ._get_lake_formation_client = FeatureGroup ._get_lake_formation_client .__get__ (fg )
96+
97+ client1 = fg ._get_lake_formation_client (session = mock_session , region = "us-west-2" )
98+ client2 = fg ._get_lake_formation_client (session = mock_session , region = "us-east-1" )
99+
100+ assert client1 is not client2
101+ assert mock_session .client .call_count == 2
102+
72103
73104class TestRegisterS3WithLakeFormation :
74105 """Tests for _register_s3_with_lake_formation method."""
@@ -942,6 +973,7 @@ def test_enable_lake_formation_called_when_enabled(
942973 use_service_linked_role = True ,
943974 registration_role_arn = None ,
944975 show_s3_policy = False ,
976+ disable_hybrid_access_mode = True ,
945977 )
946978 # Verify the feature group was returned
947979 assert result == mock_fg
@@ -1150,11 +1182,157 @@ def test_use_service_linked_role_extraction_from_config(
11501182 use_service_linked_role = use_slr ,
11511183 registration_role_arn = expected_registration_role ,
11521184 show_s3_policy = False ,
1185+ disable_hybrid_access_mode = True ,
11531186 )
11541187 # Verify the feature group was returned
11551188 assert result == mock_fg
11561189
11571190
1191+ class TestDisableHybridAccessMode :
1192+ """Tests for disable_hybrid_access_mode parameter in enable_lake_formation."""
1193+
1194+ def setup_method (self ):
1195+ """Set up test fixtures."""
1196+ from sagemaker .core .shapes import OfflineStoreConfig , S3StorageConfig , DataCatalogConfig
1197+
1198+ self .fg = FeatureGroup (feature_group_name = "test-fg" )
1199+ self .fg .offline_store_config = OfflineStoreConfig (
1200+ s3_storage_config = S3StorageConfig (
1201+ s3_uri = "s3://test-bucket/path" ,
1202+ resolved_output_s3_uri = "s3://test-bucket/resolved-path" ,
1203+ ),
1204+ data_catalog_config = DataCatalogConfig (
1205+ catalog = "AwsDataCatalog" , database = "test_db" , table_name = "test_table"
1206+ ),
1207+ )
1208+ self .fg .role_arn = "arn:aws:iam::123456789012:role/TestRole"
1209+ self .fg .feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg"
1210+ self .fg .feature_group_status = "Created"
1211+
1212+ @patch .object (FeatureGroup , "refresh" )
1213+ @patch .object (FeatureGroup , "_register_s3_with_lake_formation" )
1214+ @patch .object (FeatureGroup , "_grant_lake_formation_permissions" )
1215+ @patch .object (FeatureGroup , "_revoke_iam_allowed_principal" )
1216+ def test_revoke_called_when_disable_hybrid_access_mode_true (
1217+ self , mock_revoke , mock_grant , mock_register , mock_refresh
1218+ ):
1219+ """Test that IAMAllowedPrincipal is revoked when disable_hybrid_access_mode=True (default)."""
1220+ mock_register .return_value = True
1221+ mock_grant .return_value = True
1222+ mock_revoke .return_value = True
1223+
1224+ result = self .fg .enable_lake_formation (disable_hybrid_access_mode = True )
1225+
1226+ mock_revoke .assert_called_once ()
1227+ assert result ["iam_principal_revoked" ] is True
1228+
1229+ @patch .object (FeatureGroup , "refresh" )
1230+ @patch .object (FeatureGroup , "_register_s3_with_lake_formation" )
1231+ @patch .object (FeatureGroup , "_grant_lake_formation_permissions" )
1232+ @patch .object (FeatureGroup , "_revoke_iam_allowed_principal" )
1233+ def test_revoke_skipped_when_disable_hybrid_access_mode_false (
1234+ self , mock_revoke , mock_grant , mock_register , mock_refresh
1235+ ):
1236+ """Test that IAMAllowedPrincipal revocation is skipped when disable_hybrid_access_mode=False."""
1237+ mock_register .return_value = True
1238+ mock_grant .return_value = True
1239+
1240+ result = self .fg .enable_lake_formation (disable_hybrid_access_mode = False )
1241+
1242+ mock_revoke .assert_not_called ()
1243+ assert result ["iam_principal_revoked" ] is None
1244+ assert result ["s3_registration" ] is True
1245+ assert result ["permissions_granted" ] is True
1246+
1247+ @patch .object (FeatureGroup , "refresh" )
1248+ @patch .object (FeatureGroup , "_register_s3_with_lake_formation" )
1249+ @patch .object (FeatureGroup , "_grant_lake_formation_permissions" )
1250+ @patch .object (FeatureGroup , "_revoke_iam_allowed_principal" )
1251+ def test_default_disable_hybrid_access_mode_is_true (
1252+ self , mock_revoke , mock_grant , mock_register , mock_refresh
1253+ ):
1254+ """Test that disable_hybrid_access_mode defaults to True (revoke is called by default)."""
1255+ mock_register .return_value = True
1256+ mock_grant .return_value = True
1257+ mock_revoke .return_value = True
1258+
1259+ # Call without specifying disable_hybrid_access_mode
1260+ result = self .fg .enable_lake_formation ()
1261+
1262+ mock_revoke .assert_called_once ()
1263+ assert result ["iam_principal_revoked" ] is True
1264+
1265+
1266+ class TestCreateWithLakeFormationDisableHybridAccessMode :
1267+ """Tests for disable_hybrid_access_mode passed through create() via LakeFormationConfig."""
1268+
1269+ @patch ("sagemaker.core.resources.Base.get_sagemaker_client" )
1270+ @patch .object (FeatureGroup , "get" )
1271+ @patch .object (FeatureGroup , "wait_for_status" )
1272+ @patch .object (FeatureGroup , "enable_lake_formation" )
1273+ def test_disable_hybrid_access_mode_false_passed_through_create (
1274+ self , mock_enable_lf , mock_wait , mock_get , mock_get_client
1275+ ):
1276+ """Test that disable_hybrid_access_mode=False is passed through create() to enable_lake_formation."""
1277+ from sagemaker .core .shapes import FeatureDefinition , OfflineStoreConfig , S3StorageConfig
1278+
1279+ mock_client = MagicMock ()
1280+ mock_client .create_feature_group .return_value = {
1281+ "FeatureGroupArn" : "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test"
1282+ }
1283+ mock_get_client .return_value = mock_client
1284+
1285+ mock_fg = MagicMock (spec = FeatureGroup )
1286+ mock_fg .wait_for_status = mock_wait
1287+ mock_fg .enable_lake_formation = mock_enable_lf
1288+ mock_get .return_value = mock_fg
1289+
1290+ feature_definitions = [
1291+ FeatureDefinition (feature_name = "record_id" , feature_type = "String" ),
1292+ FeatureDefinition (feature_name = "event_time" , feature_type = "String" ),
1293+ ]
1294+
1295+ lf_config = LakeFormationConfig ()
1296+ lf_config .enabled = True
1297+ lf_config .disable_hybrid_access_mode = False
1298+
1299+ FeatureGroup .create (
1300+ feature_group_name = "test-fg" ,
1301+ record_identifier_feature_name = "record_id" ,
1302+ event_time_feature_name = "event_time" ,
1303+ feature_definitions = feature_definitions ,
1304+ offline_store_config = OfflineStoreConfig (
1305+ s3_storage_config = S3StorageConfig (s3_uri = "s3://bucket/path" )
1306+ ),
1307+ role_arn = "arn:aws:iam::123456789012:role/TestRole" ,
1308+ lake_formation_config = lf_config ,
1309+ )
1310+
1311+ mock_enable_lf .assert_called_once_with (
1312+ session = None ,
1313+ region = None ,
1314+ use_service_linked_role = True ,
1315+ registration_role_arn = None ,
1316+ show_s3_policy = False ,
1317+ disable_hybrid_access_mode = False ,
1318+ )
1319+
1320+
1321+ class TestLakeFormationConfigDefaults :
1322+ """Tests for LakeFormationConfig default values."""
1323+
1324+ def test_disable_hybrid_access_mode_defaults_to_true (self ):
1325+ """Test that LakeFormationConfig.disable_hybrid_access_mode defaults to True."""
1326+ config = LakeFormationConfig ()
1327+ assert config .disable_hybrid_access_mode is True
1328+
1329+ def test_disable_hybrid_access_mode_can_be_set_to_false (self ):
1330+ """Test that LakeFormationConfig.disable_hybrid_access_mode can be set to False."""
1331+ config = LakeFormationConfig ()
1332+ config .disable_hybrid_access_mode = False
1333+ assert config .disable_hybrid_access_mode is False
1334+
1335+
11581336class TestExtractAccountIdFromArn :
11591337 """Tests for _extract_account_id_from_arn static method."""
11601338
0 commit comments