Skip to content

Commit 6901a76

Browse files
committed
refactor(client): bind OpenFeatureClient to its API instance
1 parent d8b805c commit 6901a76

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

openfeature/client.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from __future__ import annotations
2+
13
import logging
24
import typing
35
from collections.abc import Awaitable, Mapping, Sequence
46
from dataclasses import dataclass
57
from itertools import chain
68

7-
from openfeature import _event_support
8-
from openfeature.evaluation_context import EvaluationContext, get_evaluation_context
9+
from openfeature.evaluation_context import EvaluationContext
910
from openfeature.event import EventHandler, ProviderEvent
1011
from openfeature.exception import (
1112
ErrorCode,
@@ -23,17 +24,18 @@
2324
FlagValueType,
2425
Reason,
2526
)
26-
from openfeature.hook import Hook, HookContext, HookHints, get_hooks
27+
from openfeature.hook import Hook, HookContext, HookHints
2728
from openfeature.hook._hook_support import (
2829
after_all_hooks,
2930
after_hooks,
3031
before_hooks,
3132
error_hooks,
3233
)
3334
from openfeature.provider import FeatureProvider, ProviderStatus
34-
from openfeature.provider._registry import provider_registry
3535
from openfeature.track import TrackingEventDetails
36-
from openfeature.transaction_context import get_transaction_context
36+
37+
if typing.TYPE_CHECKING:
38+
from openfeature._api import OpenFeatureAPI
3739

3840
__all__ = [
3941
"ClientMetadata",
@@ -81,18 +83,25 @@ def __init__(
8183
version: str | None,
8284
context: EvaluationContext | None = None,
8385
hooks: list[Hook] | None = None,
86+
api: OpenFeatureAPI | None = None,
8487
) -> None:
8588
self.domain = domain
8689
self.version = version
8790
self.context = context or EvaluationContext()
8891
self.hooks = hooks or []
92+
if api is not None:
93+
self._api = api
94+
else:
95+
from openfeature._api import _default_api # noqa: PLC0415
96+
97+
self._api = _default_api
8998

9099
@property
91100
def provider(self) -> FeatureProvider:
92-
return provider_registry.get_provider(self.domain)
101+
return self._api._provider_registry.get_provider(self.domain)
93102

94103
def get_provider_status(self) -> ProviderStatus:
95-
return provider_registry.get_provider_status(self.provider)
104+
return self._api._provider_registry.get_provider_status(self.provider)
96105

97106
def get_metadata(self) -> ClientMetadata:
98107
return ClientMetadata(domain=self.domain)
@@ -422,8 +431,8 @@ def _establish_hooks_and_provider(
422431
# Merge transaction context into evaluation context before creating hook_context
423432
# This ensures hooks have access to the complete context including transaction context
424433
merged_eval_context = (
425-
get_evaluation_context()
426-
.merge(get_transaction_context())
434+
self._api.get_evaluation_context()
435+
.merge(self._api.get_transaction_context())
427436
.merge(self.context)
428437
.merge(evaluation_context)
429438
)
@@ -448,7 +457,7 @@ def _establish_hooks_and_provider(
448457
),
449458
)
450459
for hook in chain(
451-
get_hooks(),
460+
self._api.get_hooks(),
452461
self.hooks,
453462
evaluation_hooks,
454463
provider.get_provider_hooks(),
@@ -540,20 +549,20 @@ async def evaluate_flag_details_async(
540549
self,
541550
flag_type: FlagType,
542551
flag_key: str,
543-
default_value: Sequence["FlagValueType"],
552+
default_value: Sequence[FlagValueType],
544553
evaluation_context: EvaluationContext | None = None,
545554
flag_evaluation_options: FlagEvaluationOptions | None = None,
546-
) -> FlagEvaluationDetails[Sequence["FlagValueType"]]: ...
555+
) -> FlagEvaluationDetails[Sequence[FlagValueType]]: ...
547556

548557
@typing.overload
549558
async def evaluate_flag_details_async(
550559
self,
551560
flag_type: FlagType,
552561
flag_key: str,
553-
default_value: Mapping[str, "FlagValueType"],
562+
default_value: Mapping[str, FlagValueType],
554563
evaluation_context: EvaluationContext | None = None,
555564
flag_evaluation_options: FlagEvaluationOptions | None = None,
556-
) -> FlagEvaluationDetails[Mapping[str, "FlagValueType"]]: ...
565+
) -> FlagEvaluationDetails[Mapping[str, FlagValueType]]: ...
557566

558567
async def evaluate_flag_details_async(
559568
self,
@@ -716,20 +725,20 @@ def evaluate_flag_details(
716725
self,
717726
flag_type: FlagType,
718727
flag_key: str,
719-
default_value: Sequence["FlagValueType"],
728+
default_value: Sequence[FlagValueType],
720729
evaluation_context: EvaluationContext | None = None,
721730
flag_evaluation_options: FlagEvaluationOptions | None = None,
722-
) -> FlagEvaluationDetails[Sequence["FlagValueType"]]: ...
731+
) -> FlagEvaluationDetails[Sequence[FlagValueType]]: ...
723732

724733
@typing.overload
725734
def evaluate_flag_details(
726735
self,
727736
flag_type: FlagType,
728737
flag_key: str,
729-
default_value: Mapping[str, "FlagValueType"],
738+
default_value: Mapping[str, FlagValueType],
730739
evaluation_context: EvaluationContext | None = None,
731740
flag_evaluation_options: FlagEvaluationOptions | None = None,
732-
) -> FlagEvaluationDetails[Mapping[str, "FlagValueType"]]: ...
741+
) -> FlagEvaluationDetails[Mapping[str, FlagValueType]]: ...
733742

734743
def evaluate_flag_details(
735744
self,
@@ -951,10 +960,10 @@ def _create_provider_evaluation(
951960
return resolution.to_flag_evaluation_details(flag_key)
952961

953962
def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None:
954-
_event_support.add_client_handler(self, event, handler)
963+
self._api._event_support.add_client_handler(self, event, handler)
955964

956965
def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None:
957-
_event_support.remove_client_handler(self, event, handler)
966+
self._api._event_support.remove_client_handler(self, event, handler)
958967

959968
def track(
960969
self,
@@ -974,8 +983,8 @@ def track(
974983
evaluation_context = EvaluationContext()
975984

976985
merged_eval_context = (
977-
get_evaluation_context()
978-
.merge(get_transaction_context())
986+
self._api.get_evaluation_context()
987+
.merge(self._api.get_transaction_context())
979988
.merge(self.context)
980989
.merge(evaluation_context)
981990
)

0 commit comments

Comments
 (0)