Skip to content

Commit 0800f13

Browse files
committed
feat: Implement Databricks Unity Catalog offline store integration
Signed-off-by: Abhishek Shinde <norizzabhii@gmail.com>
1 parent 7619222 commit 0800f13

3 files changed

Lines changed: 494 additions & 0 deletions

File tree

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
import logging
2+
from datetime import date, datetime
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
5+
import pandas as pd
6+
import pyarrow
7+
import pyspark
8+
from pydantic import StrictStr
9+
from pyspark import SparkConf
10+
from pyspark.sql import SparkSession
11+
12+
from feast import FeatureView
13+
from feast.data_source import DataSource
14+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
15+
SparkOfflineStore,
16+
SparkOfflineStoreConfig,
17+
)
18+
from feast.infra.offline_stores.offline_store import RetrievalJob
19+
from feast.infra.registry.base_registry import BaseRegistry
20+
from feast.repo_config import RepoConfig
21+
22+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
23+
24+
25+
class DatabricksUCOfflineStoreConfig(SparkOfflineStoreConfig):
    """Offline store configuration for Databricks with Unity Catalog.

    Extends ``SparkOfflineStoreConfig`` with optional Databricks connection
    settings and optional default catalog/schema names. Every field except
    ``type`` defaults to ``None``.
    """

    type: StrictStr = "databricks_uc"
    """Offline store type selector"""

    workspace_host: Optional[StrictStr] = None
    """Databricks workspace host (e.g. adb-xxxx.azuredatabricks.net)"""

    # NOTE(review): the token is held as a plain string field; confirm it is
    # never echoed into logs or serialized configuration dumps.
    token: Optional[StrictStr] = None
    """Databricks Personal Access Token (PAT)"""

    cluster_id: Optional[StrictStr] = None
    """Databricks Cluster ID to connect to for Databricks Connect"""

    default_catalog: Optional[StrictStr] = None
    """Default catalog name to use in Unity Catalog"""

    default_schema: Optional[StrictStr] = None
    """Default schema name to use in Unity Catalog"""
43+
44+
45+
def _normalize_workspace_host(workspace_host: str) -> str:
    """Reduce a workspace URL or host string to its bare host component.

    Everything up to and including the first ``://`` is dropped (so the scheme
    is removed regardless of its case), and so is everything from the first
    ``/`` onwards. Both ``https://adb-1.azuredatabricks.net/`` and
    ``adb-1.azuredatabricks.net`` therefore yield
    ``adb-1.azuredatabricks.net``.
    """
    host = workspace_host.strip()
    if "://" in host:
        host = host.split("://", 1)[1]
    # A trailing slash or path would otherwise end up in front of the port.
    return host.split("/", 1)[0]


def _quote_identifier(name: str) -> str:
    """Backtick-quote ``name`` for Spark SQL, doubling embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def get_databricks_session(
    store_config: DatabricksUCOfflineStoreConfig,
) -> SparkSession:
    """Return the active Spark session, creating a Databricks one if needed.

    If ``SparkSession.getActiveSession()`` already returns a session, it is
    reused and the connection settings in ``store_config`` are not consulted.
    Otherwise a builder is chosen as follows:

    * When both ``workspace_host`` and ``cluster_id`` are set, a Spark Connect
      connection string ``sc://<host>:443/;[token=...;]x-databricks-cluster-id=...``
      is built and passed to ``DatabricksSession.builder.remote`` (or to
      ``SparkSession.builder.remote`` when ``databricks.connect`` cannot be
      imported).
    * Otherwise the plain ``DatabricksSession.builder`` (or
      ``SparkSession.builder``) is used. A warning is logged if only one of
      the two settings was supplied, because it is then ignored.

    ``store_config.spark_conf``, when non-empty, is applied to the builder
    before ``getOrCreate()`` is called.

    Args:
        store_config: Databricks Unity Catalog offline store configuration.

    Returns:
        The reused or newly created Spark session.

    Side effects:
        On every call, sets ``spark.sql.parser.quotedRegexColumnNames`` to
        ``"true"`` on the session and, when configured, runs ``USE CATALOG``
        and ``USE SCHEMA`` for the default catalog and schema.
    """
    # Check if there is already an active session
    spark_session = SparkSession.getActiveSession()
    if not spark_session:
        workspace_host = store_config.workspace_host
        token = store_config.token
        cluster_id = store_config.cluster_id

        # Accept full URLs (any scheme case, trailing slash or path) as well
        # as bare host names.
        if workspace_host:
            workspace_host = _normalize_workspace_host(workspace_host)

        if workspace_host and cluster_id:
            # Databricks Connect V2 initialization (Spark Connect URI format)
            params = []
            if token:
                params.append(f"token={token}")
            params.append(f"x-databricks-cluster-id={cluster_id}")
            conn_str = f"sc://{workspace_host}:443/;{';'.join(params)}"

            # Only the optional import is guarded, so failures while building
            # the remote session are not mistaken for a missing package.
            try:
                from databricks.connect import DatabricksSession
            except ImportError:
                # Fallback to standard PySpark remote connect if
                # databricks-connect is not installed
                builder = SparkSession.builder.remote(conn_str)
            else:
                builder = DatabricksSession.builder.remote(conn_str)
        else:
            if workspace_host or cluster_id:
                logger.warning(
                    "Both workspace_host and cluster_id are required to open "
                    "a remote Databricks session; the value that was provided "
                    "is ignored and the default session builder is used."
                )
            try:
                from databricks.connect import DatabricksSession
            except ImportError:
                builder = SparkSession.builder
            else:
                builder = DatabricksSession.builder

        spark_conf = store_config.spark_conf
        if spark_conf:
            builder = builder.config(
                conf=SparkConf().setAll(list(spark_conf.items()))
            )

        spark_session = builder.getOrCreate()

    # Apply configuration defaults
    spark_session.conf.set("spark.sql.parser.quotedRegexColumnNames", "true")

    # Identifiers are backtick-quoted with embedded backticks doubled so that
    # an unusual catalog or schema name cannot break out of the quoting.
    if store_config.default_catalog:
        spark_session.sql(
            f"USE CATALOG {_quote_identifier(store_config.default_catalog)}"
        )
    if store_config.default_schema:
        spark_session.sql(
            f"USE SCHEMA {_quote_identifier(store_config.default_schema)}"
        )

    return spark_session
104+
105+
106+
class DatabricksUCOfflineStore(SparkOfflineStore):
    """Spark offline store variant that runs against a Databricks session.

    Each public operation first checks that the repo is configured with a
    ``DatabricksUCOfflineStoreConfig`` and obtains a session through
    ``get_databricks_session``, then forwards its arguments unchanged to the
    ``SparkOfflineStore`` static method of the same name.
    """

    @staticmethod
    def _activate_session(config: RepoConfig) -> None:
        """Check the offline store config type and obtain the Spark session.

        The returned session is discarded; the call is made for its effect of
        initializing/retrieving the Databricks session before delegating.
        NOTE(review): this presumably relies on the session becoming the
        active one that the parent implementation later picks up; confirm.

        Raises:
            AssertionError: If ``config.offline_store`` is not a
                ``DatabricksUCOfflineStoreConfig``.
        """
        offline_store_config = config.offline_store
        assert isinstance(offline_store_config, DatabricksUCOfflineStoreConfig)
        get_databricks_session(offline_store_config)

    @staticmethod
    def pull_latest_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        timestamp_field: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.pull_latest_from_table_or_query(
            config=config,
            data_source=data_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

    @staticmethod
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Optional[Union[pd.DataFrame, str, pyspark.sql.DataFrame]],
        registry: BaseRegistry,
        project: str,
        full_feature_names: bool = False,
        **kwargs,
    ) -> RetrievalJob:
        """Prepare the Databricks session, then delegate to the Spark store.

        Extra keyword arguments are passed through to the parent untouched.
        """
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.get_historical_features(
            config=config,
            feature_views=feature_views,
            feature_refs=feature_refs,
            entity_df=entity_df,
            registry=registry,
            project=project,
            full_feature_names=full_feature_names,
            **kwargs,
        )

    @staticmethod
    def pull_all_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        timestamp_field: str,
        created_timestamp_column: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> RetrievalJob:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.pull_all_from_table_or_query(
            config=config,
            data_source=data_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

    @staticmethod
    def offline_write_batch(
        config: RepoConfig,
        feature_view: FeatureView,
        table: pyarrow.Table,
        progress: Optional[Callable[[int], Any]],
    ):
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.offline_write_batch(
            config=config,
            feature_view=feature_view,
            table=table,
            progress=progress,
        )

    @staticmethod
    def compute_monitoring_metrics(
        config: RepoConfig,
        data_source: DataSource,
        feature_columns: List[Tuple[str, str]],
        timestamp_field: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        histogram_bins: int = 20,
        top_n: int = 10,
    ) -> List[Dict[str, Any]]:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.compute_monitoring_metrics(
            config=config,
            data_source=data_source,
            feature_columns=feature_columns,
            timestamp_field=timestamp_field,
            start_date=start_date,
            end_date=end_date,
            histogram_bins=histogram_bins,
            top_n=top_n,
        )

    @staticmethod
    def get_monitoring_max_timestamp(
        config: RepoConfig,
        data_source: DataSource,
        timestamp_field: str,
    ) -> Optional[datetime]:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.get_monitoring_max_timestamp(
            config=config,
            data_source=data_source,
            timestamp_field=timestamp_field,
        )

    @staticmethod
    def ensure_monitoring_tables(config: RepoConfig) -> None:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.ensure_monitoring_tables(config=config)

    @staticmethod
    def save_monitoring_metrics(
        config: RepoConfig,
        metric_type: str,
        metrics: List[Dict[str, Any]],
    ) -> None:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.save_monitoring_metrics(
            config=config,
            metric_type=metric_type,
            metrics=metrics,
        )

    @staticmethod
    def query_monitoring_metrics(
        config: RepoConfig,
        project: str,
        metric_type: str,
        filters: Optional[Dict[str, Any]] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> List[Dict[str, Any]]:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.query_monitoring_metrics(
            config=config,
            project=project,
            metric_type=metric_type,
            filters=filters,
            start_date=start_date,
            end_date=end_date,
        )

    @staticmethod
    def clear_monitoring_baseline(
        config: RepoConfig,
        project: str,
        feature_view_name: Optional[str] = None,
        feature_name: Optional[str] = None,
        data_source_type: Optional[str] = None,
    ) -> None:
        """Prepare the Databricks session, then delegate to the Spark store."""
        DatabricksUCOfflineStore._activate_session(config)
        return SparkOfflineStore.clear_monitoring_baseline(
            config=config,
            project=project,
            feature_view_name=feature_view_name,
            feature_name=feature_name,
            data_source_type=data_source_type,
        )

sdk/python/feast/repo_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
"redshift": "feast.infra.offline_stores.redshift.RedshiftOfflineStore",
9696
"snowflake.offline": "feast.infra.offline_stores.snowflake.SnowflakeOfflineStore",
9797
"spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore",
98+
"databricks_uc": "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.DatabricksUCOfflineStore",
9899
"trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore",
99100
"postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore",
100101
"athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore",

0 commit comments

Comments
 (0)