152152 postgres_container ,
153153 redis_container ,
154154)
155- from ai .backend .testutils .fixtures import DomainFixtureData , ScalingGroupFixtureData
155+ from ai .backend .testutils .fixtures import DomainFixtureData
156156from ai .backend .testutils .pants import get_parallel_slot
157157
158158log = logging .getLogger ("tests.component.conftest" )
@@ -673,16 +673,15 @@ async def resource_policy_fixture(
673673
674674
675675@pytest .fixture ()
676- async def scaling_group_fixture (
676+ async def scaling_group_name (
677677 db_engine : SAEngine ,
678678 domain_fixture : DomainFixtureData ,
679- ) -> AsyncIterator [ScalingGroupFixtureData ]:
680- """Insert a scaling group and its domain association; yield its identifiers ."""
679+ ) -> AsyncIterator [ResourceGroupName ]:
680+ """Insert a scaling group and its domain association; yield its name ."""
681681 sgroup_name = ResourceGroupName (f"sgroup-{ secrets .token_hex (6 )} " )
682682 async with db_engine .begin () as conn :
683- result = await conn .execute (
684- sa .insert (scaling_groups )
685- .values (
683+ await conn .execute (
684+ sa .insert (scaling_groups ).values (
686685 name = sgroup_name ,
687686 description = f"Test scaling group { sgroup_name } " ,
688687 is_active = True ,
@@ -691,23 +690,34 @@ async def scaling_group_fixture(
691690 scheduler = "fifo" ,
692691 scheduler_opts = ScalingGroupOpts (),
693692 )
694- .returning (scaling_groups .c .id )
695693 )
696- sgroup_id = ResourceGroupID (result .scalar_one ())
697694 await conn .execute (
698695 sa .insert (sgroups_for_domains ).values (
699696 scaling_group = sgroup_name ,
700697 domain = domain_fixture .domain_name ,
701698 )
702699 )
703- yield ScalingGroupFixtureData ( scaling_group_name = sgroup_name , scaling_group_id = sgroup_id )
700+ yield sgroup_name
704701 async with db_engine .begin () as conn :
705702 await conn .execute (
706703 sgroups_for_domains .delete ().where (sgroups_for_domains .c .scaling_group == sgroup_name )
707704 )
708705 await conn .execute (scaling_groups .delete ().where (scaling_groups .c .name == sgroup_name ))
709706
710707
708+ @pytest .fixture ()
709+ async def scaling_group_id (
710+ db_engine : SAEngine ,
711+ scaling_group_name : ResourceGroupName ,
712+ ) -> ResourceGroupID :
713+ """Return the inserted scaling group's ID."""
714+ async with db_engine .begin () as conn :
715+ result = await conn .execute (
716+ sa .select (scaling_groups .c .id ).where (scaling_groups .c .name == scaling_group_name )
717+ )
718+ return ResourceGroupID (result .scalar_one ())
719+
720+
711721@pytest .fixture ()
712722async def group_fixture (
713723 db_engine : SAEngine ,
@@ -936,7 +946,7 @@ async def regular_user_fixture(
936946async def database_fixture (
937947 admin_user_fixture : UserFixtureData ,
938948 regular_user_fixture : UserFixtureData ,
939- scaling_group_fixture : ScalingGroupFixtureData ,
949+ scaling_group_name : ResourceGroupName ,
940950) -> AsyncIterator [None ]:
941951 """Backward-compatible aggregate: requests all seed data fixtures."""
942952 yield
0 commit comments