Skip to content

Commit ca3eb14

Browse files
committed
refactor: split scaling group component fixtures
1 parent 88cdf47 commit ca3eb14

30 files changed

Lines changed: 385 additions & 366 deletions

src/ai/backend/testutils/fixtures.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from dataclasses import dataclass
55

66
from ai.backend.common.identifier.domain import DomainID, DomainName
7-
from ai.backend.common.identifier.resource_group import ResourceGroupID, ResourceGroupName
87

98
__all__ = (
109
"DomainFactory",
1110
"DomainFixtureData",
12-
"ScalingGroupFixtureData",
1311
)
1412

1513

@@ -19,10 +17,4 @@ class DomainFixtureData:
1917
domain_id: DomainID
2018

2119

22-
@dataclass(frozen=True)
23-
class ScalingGroupFixtureData:
24-
scaling_group_name: ResourceGroupName
25-
scaling_group_id: ResourceGroupID
26-
27-
2820
DomainFactory = Callable[..., Awaitable[DomainFixtureData]]

tests/component/agent_api/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
1313
from ai.backend.common.events.dispatcher import EventProducer
14+
from ai.backend.common.identifier.resource_group import ResourceGroupName
1415
from ai.backend.common.plugin.hook import HookPluginContext
1516
from ai.backend.common.types import HostPortPair, ResourceSlot
1617
from ai.backend.manager.actions.validators import ActionValidators
@@ -28,7 +29,6 @@
2829
from ai.backend.manager.repositories.scheduler.repository import SchedulerRepository
2930
from ai.backend.manager.services.agent.processors import AgentProcessors
3031
from ai.backend.manager.services.agent.service import AgentService
31-
from ai.backend.testutils.fixtures import ScalingGroupFixtureData
3232

3333

3434
@pytest.fixture()
@@ -106,11 +106,11 @@ def server_module_registries(
106106
@pytest.fixture()
107107
async def agent_fixture(
108108
db_engine: SAEngine,
109-
scaling_group_fixture: ScalingGroupFixtureData,
109+
scaling_group_name: ResourceGroupName,
110110
) -> AsyncIterator[str]:
111111
"""Insert a test agent row and yield its ID.
112112
113-
The agent references the scaling_group_fixture via FK.
113+
The agent references the scaling_group_name via FK.
114114
Teardown deletes the agent row (cascade deletes agent_resources).
115115
"""
116116
agent_id = f"i-test-agent-{secrets.token_hex(6)}"
@@ -120,7 +120,7 @@ async def agent_fixture(
120120
id=agent_id,
121121
status=AgentStatus.ALIVE,
122122
region="local",
123-
scaling_group=scaling_group_fixture.scaling_group_name,
123+
scaling_group=scaling_group_name,
124124
schedulable=True,
125125
available_slots=ResourceSlot({"cpu": "4", "mem": "8589934592"}),
126126
occupied_slots=ResourceSlot(),

tests/component/agent_api/test_agent_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
OrderDirection,
2727
)
2828
from ai.backend.common.dto.manager.query import StringFilter
29-
from ai.backend.testutils.fixtures import ScalingGroupFixtureData
29+
from ai.backend.common.identifier.resource_group import ResourceGroupName
3030

3131

3232
class TestSearchAgents:
@@ -84,10 +84,10 @@ async def test_admin_searches_agents_with_resource_group_filter(
8484
self,
8585
admin_registry: BackendAIClientRegistry,
8686
agent_fixture: str,
87-
scaling_group_fixture: ScalingGroupFixtureData,
87+
scaling_group_name: ResourceGroupName,
8888
) -> None:
8989
"""Filtering by resource_group returns agents in that scaling group."""
90-
sgroup_name = scaling_group_fixture.scaling_group_name
90+
sgroup_name = scaling_group_name
9191
result = await admin_registry.agent.search_agents(
9292
SearchAgentsRequest(
9393
filter=AgentFilter(
@@ -140,11 +140,11 @@ async def test_compound_filters_status_and_resource_group(
140140
self,
141141
admin_registry: BackendAIClientRegistry,
142142
agent_fixture: str,
143-
scaling_group_fixture: ScalingGroupFixtureData,
143+
scaling_group_name: ResourceGroupName,
144144
) -> None:
145145
"""Compound filters (status + resource_group) return intersection of conditions."""
146146
# Search with both status=ALIVE and resource_group filters
147-
sgroup_name = scaling_group_fixture.scaling_group_name
147+
sgroup_name = scaling_group_name
148148
result = await admin_registry.agent.search_agents(
149149
SearchAgentsRequest(
150150
filter=AgentFilter(

tests/component/auto_scaling_rule/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ai.backend.common.container_registry import ContainerRegistryType
1313
from ai.backend.common.data.endpoint.types import EndpointLifecycle
14+
from ai.backend.common.identifier.resource_group import ResourceGroupName
1415
from ai.backend.manager.actions.validators import ActionValidators
1516
from ai.backend.manager.actions.validators.rbac import RBACValidators
1617
from ai.backend.manager.actions.validators.rbac.bulk import BulkActionRBACValidator
@@ -37,7 +38,7 @@
3738
)
3839
from ai.backend.manager.services.deployment.processors import DeploymentProcessors
3940
from ai.backend.manager.services.deployment.service import DeploymentService
40-
from ai.backend.testutils.fixtures import DomainFixtureData, ScalingGroupFixtureData
41+
from ai.backend.testutils.fixtures import DomainFixtureData
4142

4243

4344
@dataclass
@@ -125,7 +126,7 @@ async def model_deployment_fixture(
125126
db_engine: SAEngine,
126127
domain_fixture: DomainFixtureData,
127128
group_fixture: uuid.UUID,
128-
scaling_group_fixture: ScalingGroupFixtureData,
129+
scaling_group_name: ResourceGroupName,
129130
admin_user_fixture: UserFixtureData,
130131
) -> AsyncIterator[uuid.UUID]:
131132
"""Insert a minimal EndpointRow (model deployment) and yield its UUID.
@@ -176,7 +177,7 @@ async def model_deployment_fixture(
176177
session_owner=str(admin_user_fixture.user_uuid),
177178
domain=domain_fixture.domain_name,
178179
project=str(group_fixture),
179-
resource_group=scaling_group_fixture.scaling_group_name,
180+
resource_group=scaling_group_name,
180181
lifecycle_stage=EndpointLifecycle.CREATED,
181182
url=None,
182183
)

tests/component/conftest.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@
152152
postgres_container,
153153
redis_container,
154154
)
155-
from ai.backend.testutils.fixtures import DomainFixtureData, ScalingGroupFixtureData
155+
from ai.backend.testutils.fixtures import DomainFixtureData
156156
from ai.backend.testutils.pants import get_parallel_slot
157157

158158
log = logging.getLogger("tests.component.conftest")
@@ -673,16 +673,15 @@ async def resource_policy_fixture(
673673

674674

675675
@pytest.fixture()
676-
async def scaling_group_fixture(
676+
async def scaling_group_name(
677677
db_engine: SAEngine,
678678
domain_fixture: DomainFixtureData,
679-
) -> AsyncIterator[ScalingGroupFixtureData]:
680-
"""Insert a scaling group and its domain association; yield its identifiers."""
679+
) -> AsyncIterator[ResourceGroupName]:
680+
"""Insert a scaling group and its domain association; yield its name."""
681681
sgroup_name = ResourceGroupName(f"sgroup-{secrets.token_hex(6)}")
682682
async with db_engine.begin() as conn:
683-
result = await conn.execute(
684-
sa.insert(scaling_groups)
685-
.values(
683+
await conn.execute(
684+
sa.insert(scaling_groups).values(
686685
name=sgroup_name,
687686
description=f"Test scaling group {sgroup_name}",
688687
is_active=True,
@@ -691,23 +690,34 @@ async def scaling_group_fixture(
691690
scheduler="fifo",
692691
scheduler_opts=ScalingGroupOpts(),
693692
)
694-
.returning(scaling_groups.c.id)
695693
)
696-
sgroup_id = ResourceGroupID(result.scalar_one())
697694
await conn.execute(
698695
sa.insert(sgroups_for_domains).values(
699696
scaling_group=sgroup_name,
700697
domain=domain_fixture.domain_name,
701698
)
702699
)
703-
yield ScalingGroupFixtureData(scaling_group_name=sgroup_name, scaling_group_id=sgroup_id)
700+
yield sgroup_name
704701
async with db_engine.begin() as conn:
705702
await conn.execute(
706703
sgroups_for_domains.delete().where(sgroups_for_domains.c.scaling_group == sgroup_name)
707704
)
708705
await conn.execute(scaling_groups.delete().where(scaling_groups.c.name == sgroup_name))
709706

710707

708+
@pytest.fixture()
709+
async def scaling_group_id(
710+
db_engine: SAEngine,
711+
scaling_group_name: ResourceGroupName,
712+
) -> ResourceGroupID:
713+
"""Return the inserted scaling group's ID."""
714+
async with db_engine.begin() as conn:
715+
result = await conn.execute(
716+
sa.select(scaling_groups.c.id).where(scaling_groups.c.name == scaling_group_name)
717+
)
718+
return ResourceGroupID(result.scalar_one())
719+
720+
711721
@pytest.fixture()
712722
async def group_fixture(
713723
db_engine: SAEngine,
@@ -936,7 +946,7 @@ async def regular_user_fixture(
936946
async def database_fixture(
937947
admin_user_fixture: UserFixtureData,
938948
regular_user_fixture: UserFixtureData,
939-
scaling_group_fixture: ScalingGroupFixtureData,
949+
scaling_group_name: ResourceGroupName,
940950
) -> AsyncIterator[None]:
941951
"""Backward-compatible aggregate: requests all seed data fixtures."""
942952
yield

0 commit comments

Comments
 (0)