Skip to content

Commit 4da984f

Browse files
authored
Merge branch 'master' into oidc-rbac-ssl-logging
2 parents ee69865 + 9feca77 commit 4da984f

2 files changed

Lines changed: 373 additions & 0 deletions

File tree

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
import os
13
from typing import Dict, Iterable, Literal, Optional
24

35
import pandas as pd
@@ -9,6 +11,102 @@
911
from feast.infra.common.serde import SerializedArtifacts
1012
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping
1113

14+
try:
15+
import boto3
16+
from botocore.client import Config as BotoConfig
17+
except ImportError:
18+
boto3 = None # type: ignore[assignment]
19+
BotoConfig = None # type: ignore[assignment,misc]
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
25+
"""Pre-create the S3A event log prefix before SparkContext initialisation.
26+
27+
Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
28+
SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
29+
real directories, so an empty prefix returns a 404). This function writes a
30+
zero-byte placeholder so the prefix exists before SparkContext is built.
31+
32+
This is only attempted when:
33+
- spark.eventLog.enabled == "true"
34+
- spark.eventLog.dir starts with "s3a://"
35+
Failures are non-fatal: Spark will surface its own error if the dir is still missing.
36+
"""
37+
if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
38+
return
39+
event_dir = spark_config.get("spark.eventLog.dir", "")
40+
if not event_dir.startswith("s3a://"):
41+
return
42+
43+
path = event_dir[len("s3a://") :]
44+
bucket, _, prefix = path.partition("/")
45+
prefix = prefix.rstrip("/")
46+
prefix = (prefix + "/") if prefix else prefix
47+
placeholder_key = prefix + ".keep"
48+
49+
endpoint = spark_config.get(
50+
"spark.hadoop.fs.s3a.endpoint",
51+
os.environ.get("AWS_ENDPOINT_URL", ""),
52+
)
53+
access_key = spark_config.get(
54+
"spark.hadoop.fs.s3a.access.key",
55+
os.environ.get("AWS_ACCESS_KEY_ID", ""),
56+
)
57+
secret_key = spark_config.get(
58+
"spark.hadoop.fs.s3a.secret.key",
59+
os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
60+
)
61+
session_token = (
62+
spark_config.get(
63+
"spark.hadoop.fs.s3a.session.token",
64+
os.environ.get("AWS_SESSION_TOKEN", ""),
65+
)
66+
or None
67+
)
68+
69+
try:
70+
if boto3 is None:
71+
raise ImportError("boto3 is not installed")
72+
73+
addressing_style = (
74+
"path"
75+
if spark_config.get(
76+
"spark.hadoop.fs.s3a.path.style.access", "false"
77+
).lower()
78+
== "true"
79+
else "auto"
80+
)
81+
82+
s3 = boto3.client(
83+
"s3",
84+
endpoint_url=endpoint if endpoint else None,
85+
aws_access_key_id=access_key or None,
86+
aws_secret_access_key=secret_key or None,
87+
aws_session_token=session_token,
88+
config=BotoConfig(
89+
signature_version="s3v4",
90+
s3={"addressing_style": addressing_style},
91+
),
92+
)
93+
resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
94+
if resp.get("KeyCount", 0) == 0:
95+
s3.put_object(Bucket=bucket, Key=placeholder_key, Body=b"")
96+
logger.debug(
97+
"Created S3A event log dir placeholder: s3a://%s/%s",
98+
bucket,
99+
placeholder_key,
100+
)
101+
except Exception as exc:
102+
logger.warning(
103+
"Could not pre-create S3A event log dir s3a://%s/%s — "
104+
"SparkContext may fail if the path still doesn't exist: %s",
105+
bucket,
106+
prefix,
107+
exc,
108+
)
109+
12110

13111
def get_or_create_new_spark_session(
14112
spark_config: Optional[Dict[str, str]] = None,
@@ -17,6 +115,7 @@ def get_or_create_new_spark_session(
17115
if not spark_session:
18116
spark_builder = SparkSession.builder
19117
if spark_config:
118+
_ensure_s3a_event_log_dir(spark_config)
20119
spark_builder = spark_builder.config(
21120
conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
22121
)
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
from feast.infra.compute_engines.spark.utils import _ensure_s3a_event_log_dir
4+
5+
BOTO3_PATH = "feast.infra.compute_engines.spark.utils.boto3"
6+
BOTOCONFIG_PATH = "feast.infra.compute_engines.spark.utils.BotoConfig"
7+
8+
9+
def _base_conf(event_log_dir: str) -> dict:
10+
return {
11+
"spark.eventLog.enabled": "true",
12+
"spark.eventLog.dir": event_log_dir,
13+
"spark.hadoop.fs.s3a.endpoint": "http://minio:9000",
14+
}
15+
16+
17+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3):
    """An empty S3A prefix gets a zero-byte .keep placeholder written."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 0}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

    # One-key existence probe, then the placeholder write.
    fake_s3.list_objects_v2.assert_called_once_with(
        Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1
    )
    fake_s3.put_object.assert_called_once_with(
        Bucket="my-bucket", Key="spark-events/.keep", Body=b""
    )
33+
34+
35+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3):
    """When the prefix already holds objects, no placeholder is written."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 3}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

    fake_s3.put_object.assert_not_called()
46+
47+
48+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3):
    """With event logging off, no S3 client is ever constructed."""
    conf = {"spark.eventLog.enabled": "false", "spark.eventLog.dir": "s3a://b/p/"}
    _ensure_s3a_event_log_dir(conf)
    mock_boto3.client.assert_not_called()
56+
57+
58+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3):
    """Schemes other than s3a:// (hdfs://, file://, ...) are ignored."""
    conf = {"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"}
    _ensure_s3a_event_log_dir(conf)
    mock_boto3.client.assert_not_called()
66+
67+
68+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3):
    """boto3 errors are swallowed -> SparkContext will surface its own error."""
    s3 = MagicMock()
    mock_boto3.client.return_value = s3
    s3.list_objects_v2.side_effect = Exception("connection refused")

    # Must not raise despite the failed probe.
    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

    # The probe failed, so no placeholder write should have been attempted.
    s3.put_object.assert_not_called()
77+
78+
79+
# ---------------------------------------------------------------------------
80+
# Bucket-root edge cases (s3a://bucket, s3a://bucket/)
81+
# ---------------------------------------------------------------------------
82+
83+
84+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_no_trailing_slash(mock_boto3):
    """s3a://bucket (no path) -> .keep lands at the bucket root, not /.keep."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 0}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket"))

    fake_s3.list_objects_v2.assert_called_once_with(Bucket="my-bucket", Prefix="", MaxKeys=1)
    fake_s3.put_object.assert_called_once_with(Bucket="my-bucket", Key=".keep", Body=b"")
96+
97+
98+
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_trailing_slash(mock_boto3):
    """s3a://bucket/ (trailing slash, empty prefix) -> .keep at bucket root."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 0}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/"))

    fake_s3.list_objects_v2.assert_called_once_with(Bucket="my-bucket", Prefix="", MaxKeys=1)
    fake_s3.put_object.assert_called_once_with(Bucket="my-bucket", Key=".keep", Body=b"")
110+
111+
112+
# ---------------------------------------------------------------------------
113+
# Credentials from spark config / env var fallback
114+
# ---------------------------------------------------------------------------
115+
116+
117+
@patch.dict(
    "os.environ",
    {
        "AWS_ACCESS_KEY_ID": "env-ak",
        "AWS_SECRET_ACCESS_KEY": "env-sk",  # pragma: allowlist secret
        "AWS_SESSION_TOKEN": "env-st",
    },
)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_uses_spark_config_credentials(mock_boto3):
    """fs.s3a.* credentials in spark config win over AWS_* env vars."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    conf = dict(_base_conf("s3a://my-bucket/logs/"))
    conf["spark.hadoop.fs.s3a.access.key"] = "spark-ak"
    conf["spark.hadoop.fs.s3a.secret.key"] = "spark-sk"  # pragma: allowlist secret
    conf["spark.hadoop.fs.s3a.session.token"] = "spark-st"
    _ensure_s3a_event_log_dir(conf)

    mock_boto3.client.assert_called_once()
    client_kwargs = mock_boto3.client.call_args.kwargs
    assert client_kwargs["aws_access_key_id"] == "spark-ak"
    assert client_kwargs["aws_secret_access_key"] == "spark-sk"  # pragma: allowlist secret
    assert client_kwargs["aws_session_token"] == "spark-st"
146+
147+
148+
@patch.dict(
    "os.environ",
    {
        "AWS_ACCESS_KEY_ID": "env-ak",
        "AWS_SECRET_ACCESS_KEY": "env-sk",  # pragma: allowlist secret
        "AWS_SESSION_TOKEN": "env-st",
    },
)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_falls_back_to_env_credentials(mock_boto3):
    """With no fs.s3a.* credential keys, the AWS_* env vars are used."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    mock_boto3.client.assert_called_once()
    client_kwargs = mock_boto3.client.call_args.kwargs
    assert client_kwargs["aws_access_key_id"] == "env-ak"
    assert client_kwargs["aws_secret_access_key"] == "env-sk"  # pragma: allowlist secret
    assert client_kwargs["aws_session_token"] == "env-st"
171+
172+
173+
@patch.dict("os.environ", {}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_no_credentials_passes_none(mock_boto3):
    """No credentials anywhere -> None passed to boto3 (anonymous / instance role)."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(
        {
            "spark.eventLog.enabled": "true",
            "spark.eventLog.dir": "s3a://my-bucket/logs/",
        }
    )

    mock_boto3.client.assert_called_once()
    client_kwargs = mock_boto3.client.call_args.kwargs
    assert client_kwargs["aws_access_key_id"] is None
    assert client_kwargs["aws_secret_access_key"] is None
    assert client_kwargs["aws_session_token"] is None
193+
194+
195+
# ---------------------------------------------------------------------------
196+
# Path-style addressing (MinIO / S3-compatible)
197+
# ---------------------------------------------------------------------------
198+
199+
200+
@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_path_style_when_enabled(mock_boto3, mock_config_cls):
    """path.style.access=true maps to boto3 addressing_style='path'."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    conf = dict(_base_conf("s3a://my-bucket/logs/"))
    conf["spark.hadoop.fs.s3a.path.style.access"] = "true"
    _ensure_s3a_event_log_dir(conf)

    mock_config_cls.assert_called_once()
    assert mock_config_cls.call_args.kwargs["s3"] == {"addressing_style": "path"}
217+
218+
219+
@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_virtual_hosted_style_by_default(
    mock_boto3, mock_config_cls
):
    """Without path.style.access configured, addressing_style stays 'auto'."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    mock_config_cls.assert_called_once()
    assert mock_config_cls.call_args.kwargs["s3"] == {"addressing_style": "auto"}
234+
235+
236+
# ---------------------------------------------------------------------------
237+
# Endpoint env var fallback (AWS_ENDPOINT_URL)
238+
# ---------------------------------------------------------------------------
239+
240+
241+
@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://localhost:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_endpoint_from_env(mock_boto3):
    """AWS_ENDPOINT_URL fills in when spark config carries no endpoint."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(
        {
            "spark.eventLog.enabled": "true",
            "spark.eventLog.dir": "s3a://my-bucket/logs/",
        }
    )

    mock_boto3.client.assert_called_once()
    assert mock_boto3.client.call_args.kwargs["endpoint_url"] == "http://localhost:9000"
259+
260+
261+
@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://env-endpoint:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_spark_endpoint_over_env(mock_boto3):
    """An explicit fs.s3a.endpoint beats the AWS_ENDPOINT_URL env var."""
    fake_s3 = MagicMock()
    fake_s3.list_objects_v2.return_value = {"KeyCount": 1}
    mock_boto3.client.return_value = fake_s3

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    mock_boto3.client.assert_called_once()
    assert mock_boto3.client.call_args.kwargs["endpoint_url"] == "http://minio:9000"

0 commit comments

Comments
 (0)