Skip to content

Commit df940f3

Browse files
fix: close multi-provider parity gaps
Co-authored-by: jonathan <jonathan@taplytics.com>
1 parent c27cc8f commit df940f3

File tree

5 files changed

+1649
-571
lines changed

5 files changed

+1649
-571
lines changed

openfeature/client.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,11 @@ def _establish_hooks_and_provider(
429429

430430
client_metadata = self.get_metadata()
431431
provider_metadata = provider.get_metadata()
432+
provider_hooks = (
433+
[]
434+
if self._provider_uses_internal_hooks(provider)
435+
else provider.get_provider_hooks()
436+
)
432437

433438
# Hooks need to be handled in different orders at different stages
434439
# in the flag evaluation
@@ -450,7 +455,7 @@ def _establish_hooks_and_provider(
450455
get_hooks(),
451456
self.hooks,
452457
evaluation_hooks,
453-
provider.get_provider_hooks(),
458+
provider_hooks,
454459
)
455460
]
456461
# after, error, finally: Provider, Invocation, Client, API
@@ -465,6 +470,36 @@ def _establish_hooks_and_provider(
465470
merged_eval_context,
466471
)
467472

473+
def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool:
474+
uses_internal_hooks = getattr(provider, "uses_internal_provider_hooks", None)
475+
return bool(callable(uses_internal_hooks) and uses_internal_hooks())
476+
477+
def _set_internal_provider_hook_runtime(
478+
self,
479+
provider: FeatureProvider,
480+
flag_type: FlagType,
481+
hook_hints: HookHints,
482+
) -> object | None:
483+
if not self._provider_uses_internal_hooks(provider):
484+
return None
485+
set_hook_runtime = getattr(provider, "set_internal_provider_hook_runtime", None)
486+
if not callable(set_hook_runtime):
487+
return None
488+
return set_hook_runtime(
489+
flag_type=flag_type,
490+
client_metadata=self.get_metadata(),
491+
hook_hints=hook_hints,
492+
)
493+
494+
def _reset_internal_provider_hook_runtime(
495+
self, provider: FeatureProvider, runtime_token: object | None
496+
) -> None:
497+
if runtime_token is None:
498+
return
499+
reset_hook_runtime = getattr(provider, "reset_internal_provider_hook_runtime", None)
500+
if callable(reset_hook_runtime):
501+
reset_hook_runtime(runtime_token)
502+
468503
def _assert_provider_status(
469504
self,
470505
) -> OpenFeatureError | None:
@@ -611,13 +646,21 @@ async def evaluate_flag_details_async(
611646
merged_eval_context,
612647
)
613648

614-
flag_evaluation = await self._create_provider_evaluation_async(
649+
runtime_token = self._set_internal_provider_hook_runtime(
615650
provider,
616651
flag_type,
617-
flag_key,
618-
default_value,
619-
merged_context,
652+
hook_hints,
620653
)
654+
try:
655+
flag_evaluation = await self._create_provider_evaluation_async(
656+
provider,
657+
flag_type,
658+
flag_key,
659+
default_value,
660+
merged_context,
661+
)
662+
finally:
663+
self._reset_internal_provider_hook_runtime(provider, runtime_token)
621664
if err := flag_evaluation.get_exception():
622665
error_hooks(
623666
flag_type, err, reversed_merged_hooks_and_context, hook_hints
@@ -787,13 +830,21 @@ def evaluate_flag_details(
787830
merged_eval_context,
788831
)
789832

790-
flag_evaluation = self._create_provider_evaluation(
833+
runtime_token = self._set_internal_provider_hook_runtime(
791834
provider,
792835
flag_type,
793-
flag_key,
794-
default_value,
795-
merged_context,
836+
hook_hints,
796837
)
838+
try:
839+
flag_evaluation = self._create_provider_evaluation(
840+
provider,
841+
flag_type,
842+
flag_key,
843+
default_value,
844+
merged_context,
845+
)
846+
finally:
847+
self._reset_internal_provider_hook_runtime(provider, runtime_token)
797848
if err := flag_evaluation.get_exception():
798849
error_hooks(
799850
flag_type, err, reversed_merged_hooks_and_context, hook_hints

openfeature/provider/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,17 @@
1111
from openfeature.hook import Hook
1212

1313
from .metadata import Metadata
14-
from .multi_provider import (
15-
EvaluationStrategy,
16-
FirstMatchStrategy,
17-
MultiProvider,
18-
ProviderEntry,
19-
)
2014

2115
if typing.TYPE_CHECKING:
2216
from openfeature.flag_evaluation import FlagValueType
2317

2418
__all__ = [
2519
"AbstractProvider",
20+
"ComparisonStrategy",
2621
"EvaluationStrategy",
2722
"FeatureProvider",
2823
"FirstMatchStrategy",
24+
"FirstSuccessfulStrategy",
2925
"Metadata",
3026
"MultiProvider",
3127
"ProviderEntry",
@@ -262,3 +258,13 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None:
262258
def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None:
263259
if hasattr(self, "_on_emit"):
264260
self._on_emit(self, event, details)
261+
262+
263+
from .multi_provider import ( # noqa: E402
264+
ComparisonStrategy,
265+
EvaluationStrategy,
266+
FirstMatchStrategy,
267+
FirstSuccessfulStrategy,
268+
MultiProvider,
269+
ProviderEntry,
270+
)

openfeature/provider/_registry.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,25 @@ def _initialize_provider(self, provider: FeatureProvider) -> None:
8080
try:
8181
if hasattr(provider, "initialize"):
8282
provider.initialize(self._get_evaluation_context())
83-
self.dispatch_event(
84-
provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails()
85-
)
83+
if self.get_provider_status(provider) == ProviderStatus.NOT_READY:
84+
self.dispatch_event(
85+
provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails()
86+
)
8687
except Exception as err:
8788
error_code = (
8889
err.error_code
8990
if isinstance(err, OpenFeatureError)
9091
else ErrorCode.GENERAL
9192
)
92-
self.dispatch_event(
93-
provider,
94-
ProviderEvent.PROVIDER_ERROR,
95-
ProviderEventDetails(
96-
message=f"Provider initialization failed: {err}",
97-
error_code=error_code,
98-
),
99-
)
93+
if self.get_provider_status(provider) == ProviderStatus.NOT_READY:
94+
self.dispatch_event(
95+
provider,
96+
ProviderEvent.PROVIDER_ERROR,
97+
ProviderEventDetails(
98+
message=f"Provider initialization failed: {err}",
99+
error_code=error_code,
100+
),
101+
)
100102

101103
def _shutdown_provider(self, provider: FeatureProvider) -> None:
102104
try:
@@ -115,6 +117,11 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None:
115117
provider.detach()
116118

117119
def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus:
120+
provider_status_getter = getattr(provider, "get_status", None)
121+
if callable(provider_status_getter):
122+
status = provider_status_getter()
123+
if isinstance(status, ProviderStatus):
124+
return status
118125
return self._provider_status.get(provider, ProviderStatus.NOT_READY)
119126

120127
def dispatch_event(

0 commit comments

Comments
 (0)