Skip to content

Commit d277b49

Browse files
committed
Addressed PR feedback; fixed race condition
1 parent 00a265e commit d277b49

2 files changed

Lines changed: 61 additions & 13 deletions

File tree

packages/sdk/server-ai/src/ldai/client.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def completion_config(
4040
"""
4141
self._client.track('$ld:ai:config:function:single', context, key, 1)
4242

43-
model, provider, messages, instructions, tracker, enabled, judge_configuration = self.__evaluate(
43+
model, provider, messages, instructions, tracker, enabled, judge_configuration, _ = self.__evaluate(
4444
key, context, default_value.to_dict(), variables
4545
)
4646

@@ -94,31 +94,29 @@ def judge_config(
9494
"""
9595
self._client.track('$ld:ai:judge:function:single', context, key, 1)
9696

97-
model, provider, messages, instructions, tracker, enabled, judge_configuration = self.__evaluate(
97+
model, provider, messages, instructions, tracker, enabled, judge_configuration, variation = self.__evaluate(
9898
key, context, default_value.to_dict(), variables
9999
)
100100

101-
variation = self._client.variation(key, context, default_value.to_dict())
102-
103101
def _extract_evaluation_metric_key(
104102
variation: Dict[str, Any], default_value: AIJudgeConfigDefault
105103
) -> Optional[str]:
106104
"""
107105
Extract evaluation_metric_key with backward compatibility.
108106
109-
Priority: 1) evaluationMetricKey from variation, 2) evaluation_metric_key from default,
110-
3) first from evaluationMetricKeys in variation, 4) first from evaluation_metric_keys in default
107+
Priority: 1) evaluationMetricKey from variation, 2) evaluationMetricKeys from variation,
108+
3) evaluation_metric_key from default, 4) evaluation_metric_keys from default
111109
"""
112110
if evaluation_metric_key := variation.get('evaluationMetricKey'):
113111
return evaluation_metric_key
114112

115-
if default_value.evaluation_metric_key:
116-
return default_value.evaluation_metric_key
117-
118113
variation_keys = variation.get('evaluationMetricKeys')
119114
if isinstance(variation_keys, list) and variation_keys:
120115
return variation_keys[0]
121116

117+
if default_value.evaluation_metric_key:
118+
return default_value.evaluation_metric_key
119+
122120
if default_value.evaluation_metric_keys:
123121
return default_value.evaluation_metric_keys[0]
124122

@@ -458,7 +456,7 @@ def __evaluate(
458456
variables: Optional[Dict[str, Any]] = None,
459457
) -> Tuple[
460458
Optional[ModelConfig], Optional[ProviderConfig], Optional[List[LDMessage]],
461-
Optional[str], LDAIConfigTracker, bool, Optional[Any]
459+
Optional[str], LDAIConfigTracker, bool, Optional[Any], Dict[str, Any]
462460
]:
463461
"""
464462
Internal method to evaluate a configuration and extract components.
@@ -467,7 +465,7 @@ def __evaluate(
467465
:param context: The evaluation context.
468466
:param default_dict: Default configuration as dictionary.
469467
:param variables: Variables for interpolation.
470-
:return: Tuple of (model, provider, messages, instructions, tracker, enabled).
468+
:return: Tuple of (model, provider, messages, instructions, tracker, enabled, judge_configuration, variation).
471469
"""
472470
variation = self._client.variation(key, context, default_dict)
473471

@@ -536,7 +534,7 @@ def __evaluate(
536534
if judges:
537535
judge_configuration = JudgeConfiguration(judges=judges)
538536

539-
return model, provider_config, messages, instructions, tracker, enabled, judge_configuration
537+
return model, provider_config, messages, instructions, tracker, enabled, judge_configuration, variation
540538

541539
def __evaluate_agent(
542540
self,
@@ -554,7 +552,7 @@ def __evaluate_agent(
554552
:param variables: Variables for interpolation.
555553
:return: Configured AIAgentConfig instance.
556554
"""
557-
model, provider, messages, instructions, tracker, enabled, judge_configuration = self.__evaluate(
555+
model, provider, messages, instructions, tracker, enabled, judge_configuration, _ = self.__evaluate(
558556
key, context, default_value.to_dict(), variables
559557
)
560558

packages/sdk/server-ai/tests/test_judge.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,53 @@ def test_judge_config_prefers_evaluation_metric_key_over_keys(
595595

596596
assert config is not None
597597
assert config.evaluation_metric_key == '$ld:ai:judge:preferred'
598+
599+
def test_judge_config_uses_same_variation_for_consistency(
600+
self, context: Context
601+
):
602+
"""judge_config should use the same variation from __evaluate to avoid race conditions."""
603+
from ldai import LDAIClient
604+
from ldclient import Config, LDClient
605+
from ldclient.integrations.test_data import TestData
606+
from unittest.mock import patch
607+
608+
td = TestData.data_source()
609+
td.update(
610+
td.flag('judge-consistency-test')
611+
.variations(
612+
{
613+
'model': {'name': 'gpt-4'},
614+
'provider': {'name': 'openai'},
615+
'messages': [{'role': 'system', 'content': 'You are a judge.'}],
616+
'evaluationMetricKey': '$ld:ai:judge:from-flag',
617+
'_ldMeta': {'enabled': True, 'variationKey': 'judge-v1', 'version': 1},
618+
}
619+
)
620+
.variation_for_all(0)
621+
)
622+
623+
test_client = LDClient(Config('sdk-key', update_processor_class=td, send_events=False))
624+
ldai_client = LDAIClient(test_client)
625+
626+
default_value = AIJudgeConfigDefault(
627+
enabled=True,
628+
evaluation_metric_key='$ld:ai:judge:from-default',
629+
messages=[LDMessage(role='system', content='You are a judge.')],
630+
model=ModelConfig('gpt-4'),
631+
provider=ProviderConfig('openai'),
632+
)
633+
634+
variation_calls = []
635+
original_variation = test_client.variation
636+
637+
def tracked_variation(key, context, default):
638+
result = original_variation(key, context, default)
639+
variation_calls.append((key, result.get('evaluationMetricKey')))
640+
return result
641+
642+
with patch.object(test_client, 'variation', side_effect=tracked_variation):
643+
config = ldai_client.judge_config('judge-consistency-test', context, default_value)
644+
645+
assert len(variation_calls) == 1, f"Expected 1 variation call, got {len(variation_calls)}"
646+
assert config is not None
647+
assert config.evaluation_metric_key == '$ld:ai:judge:from-flag'

0 commit comments

Comments
 (0)