Skip to content

Commit bdcbcf7

Browse files
committed
Removed excess LakeFormation tests and readded Athena Query tests
1 parent f0c84e3 commit bdcbcf7

2 files changed

Lines changed: 114 additions & 2307 deletions

File tree

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# Licensed under the Apache License, Version 2.0
3+
"""Unit tests for athena_query.py"""
4+
import os
5+
import pytest
6+
from unittest.mock import Mock, patch, MagicMock
7+
import pandas as pd
8+
9+
from sagemaker.mlops.feature_store.athena_query import AthenaQuery
10+
11+
12+
class TestAthenaQuery:
13+
@pytest.fixture
14+
def mock_session(self):
15+
session = Mock()
16+
session.boto_session.client.return_value = Mock()
17+
session.boto_region_name = "us-west-2"
18+
session.sagemaker_config = {}
19+
return session
20+
21+
@pytest.fixture
22+
def athena_query(self, mock_session):
23+
return AthenaQuery(
24+
catalog="AwsDataCatalog",
25+
database="sagemaker_featurestore",
26+
table_name="my_feature_group",
27+
sagemaker_session=mock_session,
28+
)
29+
30+
def test_initialization(self, athena_query):
31+
assert athena_query.catalog == "AwsDataCatalog"
32+
assert athena_query.database == "sagemaker_featurestore"
33+
assert athena_query.table_name == "my_feature_group"
34+
assert athena_query._current_query_execution_id is None
35+
36+
@patch("sagemaker.mlops.feature_store.athena_query.start_query_execution")
37+
def test_run_starts_query(self, mock_start, athena_query):
38+
mock_start.return_value = {"QueryExecutionId": "query-123"}
39+
40+
result = athena_query.run(
41+
query_string="SELECT * FROM table",
42+
output_location="s3://bucket/output",
43+
)
44+
45+
assert result == "query-123"
46+
assert athena_query._current_query_execution_id == "query-123"
47+
assert athena_query._result_bucket == "bucket"
48+
assert athena_query._result_file_prefix == "output"
49+
50+
@patch("sagemaker.mlops.feature_store.athena_query.start_query_execution")
51+
def test_run_with_kms_key(self, mock_start, athena_query):
52+
mock_start.return_value = {"QueryExecutionId": "query-123"}
53+
54+
athena_query.run(
55+
query_string="SELECT * FROM table",
56+
output_location="s3://bucket/output",
57+
kms_key="arn:aws:kms:us-west-2:123:key/abc",
58+
)
59+
60+
mock_start.assert_called_once()
61+
call_kwargs = mock_start.call_args[1]
62+
assert call_kwargs["kms_key"] == "arn:aws:kms:us-west-2:123:key/abc"
63+
64+
@patch("sagemaker.mlops.feature_store.athena_query.wait_for_athena_query")
65+
def test_wait_calls_helper(self, mock_wait, athena_query):
66+
athena_query._current_query_execution_id = "query-123"
67+
68+
athena_query.wait()
69+
70+
mock_wait.assert_called_once_with(athena_query.sagemaker_session, "query-123")
71+
72+
@patch("sagemaker.mlops.feature_store.athena_query.get_query_execution")
73+
def test_get_query_execution(self, mock_get, athena_query):
74+
athena_query._current_query_execution_id = "query-123"
75+
mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}}
76+
77+
result = athena_query.get_query_execution()
78+
79+
assert result["QueryExecution"]["Status"]["State"] == "SUCCEEDED"
80+
81+
@patch("sagemaker.mlops.feature_store.athena_query.get_query_execution")
82+
@patch("sagemaker.mlops.feature_store.athena_query.download_athena_query_result")
83+
@patch("pandas.read_csv")
84+
@patch("os.path.join")
85+
def test_as_dataframe_success(self, mock_join, mock_read_csv, mock_download, mock_get, athena_query):
86+
athena_query._current_query_execution_id = "query-123"
87+
athena_query._result_bucket = "bucket"
88+
athena_query._result_file_prefix = "prefix"
89+
90+
mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}}
91+
mock_join.return_value = "/tmp/query-123.csv"
92+
mock_read_csv.return_value = pd.DataFrame({"col": [1, 2, 3]})
93+
94+
with patch("tempfile.gettempdir", return_value="/tmp"):
95+
with patch("os.remove"):
96+
df = athena_query.as_dataframe()
97+
98+
assert len(df) == 3
99+
100+
@patch("sagemaker.mlops.feature_store.athena_query.get_query_execution")
101+
def test_as_dataframe_raises_when_running(self, mock_get, athena_query):
102+
athena_query._current_query_execution_id = "query-123"
103+
mock_get.return_value = {"QueryExecution": {"Status": {"State": "RUNNING"}}}
104+
105+
with pytest.raises(RuntimeError, match="still executing"):
106+
athena_query.as_dataframe()
107+
108+
@patch("sagemaker.mlops.feature_store.athena_query.get_query_execution")
109+
def test_as_dataframe_raises_when_failed(self, mock_get, athena_query):
110+
athena_query._current_query_execution_id = "query-123"
111+
mock_get.return_value = {"QueryExecution": {"Status": {"State": "FAILED"}}}
112+
113+
with pytest.raises(RuntimeError, match="failed"):
114+
athena_query.as_dataframe()

0 commit comments

Comments
 (0)