@@ -1535,22 +1535,146 @@ def _get_gpus_grouped_by_backend_and_region(backend_gpus: List[BackendGpus]) ->
15351535 )
15361536
15371537
def _get_gpus_grouped_by_count(backend_gpus: List[BackendGpus]) -> List[GpuGroup]:
    """Aggregate GPU specs into groups keyed by GPU model and GPU count.

    Offers that share the same name, memory size, vendor, and count collapse
    into a single ``GpuGroup``. The first matching offer seeds the group; every
    later one is merged into it through ``_update_gpu_group``.

    Args:
        backend_gpus: Per-backend GPU listings. Each entry supplies ``gpus``
            (the offers to aggregate) and ``backend_type``.

    Returns:
        The groups sorted so that those with at least one availability entry
        reporting ``is_available()`` come first, then by minimum price,
        maximum price, minimum count, name, and memory size.
    """
    gpu_rows: Dict[Tuple, GpuGroup] = {}
    for backend in backend_gpus:
        for gpu in backend.gpus:
            key = (gpu.name, gpu.memory_mib, gpu.vendor, gpu.count)
            # Single lookup instead of a membership test followed by indexing.
            group = gpu_rows.get(key)
            if group is not None:
                _update_gpu_group(group, gpu, backend.backend_type)
                continue
            gpu_rows[key] = GpuGroup(
                name=gpu.name,
                memory_mib=gpu.memory_mib,
                vendor=gpu.vendor,
                availability=[gpu.availability],
                spot=["spot" if gpu.spot else "on-demand"],
                count=Range[int](min=gpu.count, max=gpu.count),
                price=Range[float](min=gpu.price, max=gpu.price),
                backends=[backend.backend_type],
            )

    def _sort_key(g: GpuGroup) -> Tuple:
        # ``False`` sorts before ``True``, so groups with any available entry
        # are listed ahead of groups with none.
        return (
            not any(av.is_available() for av in g.availability),
            g.price.min,
            g.price.max,
            g.count.min,
            g.name,
            g.memory_mib,
        )

    # ``sorted`` accepts any iterable and always returns a new list, so the
    # dict view is passed directly rather than being copied into a list first.
    return sorted(gpu_rows.values(), key=_sort_key)
1569+
1570+
def _get_gpus_grouped_by_backend_and_count(backend_gpus: List[BackendGpus]) -> List[GpuGroup]:
    """Aggregate GPU specs into groups keyed by GPU model, backend, and GPU count.

    Offers that share the same name, memory size, vendor, backend type, and
    count collapse into a single ``GpuGroup``. The first matching offer seeds
    the group; every later one is merged into it through ``_update_gpu_group``.

    Args:
        backend_gpus: Per-backend GPU listings. Each entry supplies ``gpus``
            (the offers to aggregate), ``backend_type``, and ``regions``.

    Returns:
        The groups sorted so that those with at least one availability entry
        reporting ``is_available()`` come first, then by minimum price,
        maximum price, backend value, minimum count, name, and memory size.
    """
    gpu_rows: Dict[Tuple, GpuGroup] = {}
    for backend in backend_gpus:
        for gpu in backend.gpus:
            key = (gpu.name, gpu.memory_mib, gpu.vendor, backend.backend_type, gpu.count)
            # Single lookup instead of a membership test followed by indexing.
            group = gpu_rows.get(key)
            if group is not None:
                _update_gpu_group(group, gpu, backend.backend_type)
                continue
            gpu_rows[key] = GpuGroup(
                name=gpu.name,
                memory_mib=gpu.memory_mib,
                vendor=gpu.vendor,
                availability=[gpu.availability],
                spot=["spot" if gpu.spot else "on-demand"],
                count=Range[int](min=gpu.count, max=gpu.count),
                price=Range[float](min=gpu.price, max=gpu.price),
                backend=backend.backend_type,
                # Copied so the group does not share the backend's own list.
                regions=backend.regions.copy(),
            )

    def _sort_key(g: GpuGroup) -> Tuple:
        # ``False`` sorts before ``True``, so groups with any available entry
        # are listed ahead of groups with none.
        return (
            not any(av.is_available() for av in g.availability),
            g.price.min,
            g.price.max,
            g.backend.value,
            g.count.min,
            g.name,
            g.memory_mib,
        )

    # ``sorted`` accepts any iterable and always returns a new list, so the
    # dict view is passed directly rather than being copied into a list first.
    return sorted(gpu_rows.values(), key=_sort_key)
1604+
1605+
def _get_gpus_grouped_by_backend_region_and_count(
    backend_gpus: List[BackendGpus],
) -> List[GpuGroup]:
    """Aggregate GPU specs into groups keyed by GPU model, backend, region, and count.

    Every GPU offer of a backend is paired with every region of that backend.
    Pairs that share the same name, memory size, vendor, backend type, region,
    and count collapse into a single ``GpuGroup``. The first matching pair
    seeds the group; every later one is merged into it through
    ``_update_gpu_group``.

    Args:
        backend_gpus: Per-backend GPU listings. Each entry supplies ``gpus``
            (the offers to aggregate), ``backend_type``, and ``regions``.

    Returns:
        The groups sorted so that those with at least one availability entry
        reporting ``is_available()`` come first, then by minimum price,
        maximum price, backend value, region, minimum count, name, and memory
        size.
    """
    gpu_rows: Dict[Tuple, GpuGroup] = {}
    for backend in backend_gpus:
        # Regions stay in the outer loop so groups are inserted in the same
        # order as before; the final sort is stable, so ties keep that order.
        for region in backend.regions:
            for gpu in backend.gpus:
                key = (
                    gpu.name,
                    gpu.memory_mib,
                    gpu.vendor,
                    backend.backend_type,
                    region,
                    gpu.count,
                )
                # Single lookup instead of a membership test followed by indexing.
                group = gpu_rows.get(key)
                if group is not None:
                    _update_gpu_group(group, gpu, backend.backend_type)
                    continue
                gpu_rows[key] = GpuGroup(
                    name=gpu.name,
                    memory_mib=gpu.memory_mib,
                    vendor=gpu.vendor,
                    availability=[gpu.availability],
                    spot=["spot" if gpu.spot else "on-demand"],
                    count=Range[int](min=gpu.count, max=gpu.count),
                    price=Range[float](min=gpu.price, max=gpu.price),
                    backend=backend.backend_type,
                    region=region,
                )

    def _sort_key(g: GpuGroup) -> Tuple:
        # ``False`` sorts before ``True``, so groups with any available entry
        # are listed ahead of groups with none.
        return (
            not any(av.is_available() for av in g.availability),
            g.price.min,
            g.price.max,
            g.backend.value,
            g.region,
            g.count.min,
            g.name,
            g.memory_mib,
        )

    # ``sorted`` accepts any iterable and always returns a new list, so the
    # dict view is passed directly rather than being copied into a list first.
    return sorted(gpu_rows.values(), key=_sort_key)
1650+
1651+
15381652async def get_run_gpus_grouped (
15391653 session : AsyncSession ,
15401654 project : ProjectModel ,
15411655 run_spec : RunSpec ,
1542- group_by : Optional [List [Literal ["backend" , "region" ]]] = None ,
1656+ group_by : Optional [List [Literal ["backend" , "region" , "count" ]]] = None ,
15431657) -> RunGpusResponse :
15441658 """Retrieves available GPU specifications based on a run spec, with optional grouping."""
15451659 offers = await _get_gpu_offers (session , project , run_spec )
15461660 backend_gpus = _process_offers_into_backend_gpus (offers )
15471661
15481662 group_by_set = set (group_by ) if group_by else set ()
15491663
1550- if "backend" in group_by_set and "region" in group_by_set :
1664+ # Determine grouping strategy based on combination
1665+ has_backend = "backend" in group_by_set
1666+ has_region = "region" in group_by_set
1667+ has_count = "count" in group_by_set
1668+ if has_backend and has_region and has_count :
1669+ gpus = _get_gpus_grouped_by_backend_region_and_count (backend_gpus )
1670+ elif has_backend and has_count :
1671+ gpus = _get_gpus_grouped_by_backend_and_count (backend_gpus )
1672+ elif has_backend and has_region :
15511673 gpus = _get_gpus_grouped_by_backend_and_region (backend_gpus )
1552- elif "backend" in group_by_set :
1674+ elif has_backend :
15531675 gpus = _get_gpus_grouped_by_backend (backend_gpus )
1676+ elif has_count :
1677+ gpus = _get_gpus_grouped_by_count (backend_gpus )
15541678 else :
15551679 gpus = _get_gpus_with_no_grouping (backend_gpus )
15561680
0 commit comments