Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion packages/sdk/server-ai/src/ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,10 +832,12 @@ def __evaluate(
if 'model' in variation and isinstance(variation['model'], dict):
parameters = variation['model'].get('parameters', None)
custom = variation['model'].get('custom', None)
region = variation['model'].get('region', None)
model = ModelConfig(
name=variation['model']['name'],
parameters=parameters,
custom=custom
custom=custom,
region=region,
)

variation_key = variation.get('_ldMeta', {}).get('variationKey', '')
Expand Down
12 changes: 11 additions & 1 deletion packages/sdk/server-ai/src/ldai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,17 @@ class ModelConfig:
Configuration related to the model.
"""

def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None, region: Optional[str] = None):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like a lint issue with the length of this line.

"""
:param name: The name of the model.
:param parameters: Additional model-specific parameters.
:param custom: Additional customer provided data.
:param region: The region the model is deployed in.
"""
self._name = name
self._parameters = parameters
self._custom = custom
self._region = region

@property
def name(self) -> str:
Expand Down Expand Up @@ -93,6 +95,13 @@ def get_custom(self, key: str) -> Any:

return self._custom.get(key)

@property
def region(self) -> Optional[str]:
"""
The region the model is deployed in.
"""
return self._region

def to_dict(self) -> dict:
"""
Render the given model config as a dictionary object.
Expand All @@ -101,6 +110,7 @@ def to_dict(self) -> dict:
'name': self._name,
'parameters': self._parameters,
'custom': self._custom,
'region': self._region,
}


Expand Down
47 changes: 47 additions & 0 deletions packages/sdk/server-ai/tests/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ def td() -> TestData:
.variation_for_all(0)
)

td.update(
td.flag('model-config-with-region')
.variations(
{
'model': {
'name': 'anthropic.claude-opus-4-7',
'parameters': {},
'region': 'us',
},
'provider': {'name': 'Bedrock'},
'messages': [{'role': 'system', 'content': 'Hello!'}],
'_ldMeta': {'enabled': True, 'variationKey': 'us-variation', 'version': 1},
},
)
.variation_for_all(0)
)

td.update(
td.flag('multiple-messages')
.variations(
Expand Down Expand Up @@ -482,6 +499,36 @@ def test_create_tracker_preserves_config_metadata():
assert 'runId' in track_data


def test_model_config_region():
model = ModelConfig('fakeModel', region='us')
assert model.region == 'us'


def test_model_config_region_defaults_to_none():
model = ModelConfig('fakeModel')
assert model.region is None


def test_model_config_region_from_flag(ldai_client: LDAIClient):
context = Context.create('user-key')
default = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.completion_config('model-config-with-region', context, default)

assert config.model is not None
assert config.model.region == 'us'


def test_model_config_no_region_is_none(ldai_client: LDAIClient):
context = Context.create('user-key')
default = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.completion_config('model-config', context, default)

assert config.model is not None
assert config.model.region is None


def test_create_tracker_each_call_has_different_run_id():
from unittest.mock import Mock

Expand Down
Loading