Skip to content

Commit 0bc8cc2

Browse files
authored
Add CatalogItem.provider_data + reserved GCP A4 (#184)
Add a catalog item field for holding arbitrary provider-specific data. As an example, use this field for reserved GCP A4 instances. Below is a comparison of the different field types that could be used for this field in both gpuhunt and dstack. This commit suggests using typed dicts. Option: JSON string (parsing into a Pydantic model on the provider/backend side). Cons: - Inefficient, inconvenient, and error-prone for writing: unnecessary serialization and deserialization when updating an attribute of an already serialized object. - Possibility of incorrect usage (writing non-JSON data). Option: `dict` / `TypedDict`. Cons: - No validation, which means potential errors occur at attribute access time rather than at model loading time. Option: Pydantic models with a `type` discriminator. Cons: - Extra difficulties maintaining backward compatibility, as the models are passed from gpuhunt to dstack server, from server to client, and from client to server, all with validation. - Duplication of backend type in the backend-specific field and in other fields of the offer or catalog (e.g., `InstanceOffer.backend` and `InstanceOffer.backend_data.type`). - Discriminators require declaring all possible discriminator values, which in the future will hinder the transition to a more modular architecture with backend plugins. - Backward compatibility issues when a new discriminator value (a new backend) is introduced. Option: Pydantic models + custom deserialization logic (e.g., a custom `InstanceOffer` deserializer that determines the `InstanceOffer.backend_data` model based on `InstanceOffer.backend`). Cons: - Extra difficulties maintaining backward compatibility, as the models are passed from gpuhunt to dstack server, from server to client, and from client to server, all with validation. - The need to duplicate deserialization logic in all models that hold the backend-specific field - at least in `RawCatalogItem` and `InstanceOffer`.
1 parent 95122e6 commit 0bc8cc2

3 files changed

Lines changed: 76 additions & 25 deletions

File tree

src/gpuhunt/_internal/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import enum
2+
import json
23
from collections.abc import Container
34
from dataclasses import asdict, dataclass, field, fields
45
from typing import (
@@ -9,6 +10,17 @@
910

1011
from gpuhunt._internal.utils import empty_as_none
1112

13+
JSONType = Union[
14+
None,
15+
bool,
16+
int,
17+
float,
18+
str,
19+
list["JSONType"],
20+
"JSONObject",
21+
]
22+
JSONObject = dict[str, JSONType]
23+
1224

1325
def bool_loader(x: Union[bool, str]) -> bool:
1426
if isinstance(x, bool):
@@ -73,6 +85,7 @@ class RawCatalogItem:
7385
gpu_vendor: Optional[str] = None
7486
flags: list[str] = field(default_factory=list)
7587
cpu_arch: Optional[str] = None
88+
provider_data: JSONObject = field(default_factory=dict)
7689

7790
def __post_init__(self) -> None:
7891
self._process_gpu_vendor()
@@ -121,12 +134,14 @@ def from_dict(v: dict) -> "RawCatalogItem":
121134
spot=empty_as_none(v.get("spot"), loader=bool_loader),
122135
disk_size=empty_as_none(v.get("disk_size"), loader=float),
123136
flags=v.get("flags", "").split(),
137+
provider_data=json.loads(v.get("provider_data", "{}")),
124138
)
125139

126140
def dict(self) -> dict[str, Union[str, int, float, bool, None]]:
127141
return {
128142
**asdict(self),
129143
"flags": " ".join(self.flags),
144+
"provider_data": json.dumps(self.provider_data),
130145
}
131146

132147

@@ -153,6 +168,8 @@ class CatalogItem:
153168
will have to request this flag explicitly to get the catalog item.
154169
If you are adding a new provider, leave the flags empty.
155170
Flag names should be in kebab-case.
171+
provider_data: dict with provider-specific properties.
172+
Prefer defining a TypedDict within provider implementation.
156173
"""
157174

158175
instance_name: str
@@ -169,6 +186,7 @@ class CatalogItem:
169186
gpu_vendor: Optional[AcceleratorVendor] = None
170187
flags: list[str] = field(default_factory=list)
171188
cpu_arch: Optional[CPUArchitecture] = None
189+
provider_data: JSONObject = field(default_factory=dict)
172190

173191
def __post_init__(self) -> None:
174192
self._process_gpu_vendor()

src/gpuhunt/providers/gcp.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import enum
23
import importlib.resources
34
import json
45
import logging
@@ -7,14 +8,15 @@
78
from collections.abc import Iterable
89
from concurrent.futures import ThreadPoolExecutor, as_completed
910
from dataclasses import dataclass
10-
from typing import Optional
11+
from typing import Optional, cast
1112

1213
import google.cloud.billing_v1 as billing_v1
1314
import google.cloud.compute_v1 as compute_v1
1415
from google.cloud import tpu_v2
1516
from google.cloud.billing_v1 import CloudCatalogClient, ListSkusRequest
1617
from google.cloud.billing_v1.types.cloud_catalog import Sku
1718
from google.cloud.location import locations_pb2
19+
from typing_extensions import NotRequired, TypedDict
1820

1921
from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem
2022
from gpuhunt.providers import AbstractProvider
@@ -243,14 +245,17 @@ def fill_prices(self, instances: list[RawCatalogItem]) -> list[RawCatalogItem]:
243245

244246
offers = []
245247
for instance in instances:
246-
for spot in (False, True):
247-
price = prices.get_instance_price(instance, spot)
248+
for capacity_type in CapacityType:
249+
price = prices.get_instance_price(instance, capacity_type)
248250
if price is None:
249251
continue
250252

251253
offer = copy.deepcopy(instance)
252254
offer.price = round(price, 6)
253-
offer.spot = spot
255+
offer.spot = capacity_type is CapacityType.SPOT
256+
cast(GCPCatalogItemProviderData, offer.provider_data)["is_dws_calendar_mode"] = (
257+
capacity_type is CapacityType.DWS_CALENDAR_MODE
258+
)
254259
offers.append(offer)
255260
return offers
256261

@@ -307,6 +312,9 @@ def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]:
307312
# Filter out on-demand offers that are not actually available on demand.
308313
# https://cloud.google.com/compute/docs/accelerator-optimized-machines#consumption_option_availability_by_machine_type
309314
i.spot == False
315+
and not cast(GCPCatalogItemProviderData, i.provider_data).get(
316+
"is_dws_calendar_mode"
317+
)
310318
and (
311319
i.instance_name.startswith("a4x-")
312320
or i.instance_name.startswith("a4-")
@@ -320,17 +328,27 @@ def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]:
320328
]
321329

322330

323-
RegionSpot = tuple[str, bool]
324-
PricePerRegionSpot = dict[RegionSpot, float]
331+
class GCPCatalogItemProviderData(TypedDict):
332+
is_dws_calendar_mode: NotRequired[bool]
333+
334+
335+
class CapacityType(enum.Enum):
336+
ON_DEMAND = enum.auto()
337+
SPOT = enum.auto()
338+
DWS_CALENDAR_MODE = enum.auto()
339+
340+
341+
RegionCapacityType = tuple[str, CapacityType]
342+
PricePerRegionCapacityType = dict[RegionCapacityType, float]
325343

326344

327345
class Prices:
328346
def __init__(self):
329-
self.cpu: defaultdict[str, PricePerRegionSpot] = defaultdict(dict)
330-
self.gpu: defaultdict[str, PricePerRegionSpot] = defaultdict(dict)
331-
self.ram: defaultdict[str, PricePerRegionSpot] = defaultdict(dict)
332-
self.local_ssd: PricePerRegionSpot = dict()
333-
self.gpu_slice: defaultdict[str, PricePerRegionSpot] = defaultdict(dict)
347+
self.cpu: defaultdict[str, PricePerRegionCapacityType] = defaultdict(dict)
348+
self.gpu: defaultdict[str, PricePerRegionCapacityType] = defaultdict(dict)
349+
self.ram: defaultdict[str, PricePerRegionCapacityType] = defaultdict(dict)
350+
self.local_ssd: PricePerRegionCapacityType = dict()
351+
self.gpu_slice: defaultdict[str, PricePerRegionCapacityType] = defaultdict(dict)
334352

335353
def add_skus(self, skus: Iterable[Sku]) -> None:
336354
for sku in skus:
@@ -395,7 +413,9 @@ def add_compute_sku(self, sku: Sku) -> None:
395413
self._add_price(sku, resource_prices[family], price)
396414

397415
def add_compute_gpu_slice_sku(self, sku: Sku) -> None:
398-
if sku.description.startswith("Spot Preemptible A4 Nvidia B200"):
416+
if sku.description.startswith(
417+
"Spot Preemptible A4 Nvidia B200"
418+
) or sku.description.startswith("DWS Calendar Mode A4 Nvidia B200"):
399419
gpu = "nvidia-b200"
400420
else:
401421
return
@@ -413,34 +433,43 @@ def _calculate_sku_price(sku: Sku) -> float:
413433
return price.units + price.nanos / 1e9
414434

415435
@staticmethod
416-
def _add_price(sku: Sku, family_prices: PricePerRegionSpot, price: float) -> None:
417-
spot = sku.category.usage_type == "Preemptible"
436+
def _add_price(sku: Sku, family_prices: PricePerRegionCapacityType, price: float) -> None:
437+
if sku.category.usage_type == "Preemptible":
438+
capacity_type = CapacityType.SPOT
439+
elif "DWS Calendar Mode" in sku.description:
440+
capacity_type = CapacityType.DWS_CALENDAR_MODE
441+
else:
442+
capacity_type = CapacityType.ON_DEMAND
418443
for region in sku.service_regions:
419-
family_prices[(region, spot)] = price
444+
family_prices[(region, capacity_type)] = price
420445

421-
def get_instance_price(self, instance: RawCatalogItem, spot: bool) -> Optional[float]:
446+
def get_instance_price(
447+
self, instance: RawCatalogItem, capacity_type: CapacityType
448+
) -> Optional[float]:
422449
vm_family = self.get_vm_family(instance.instance_name)
423450
if vm_family in ["g1", "f1", "m2"]: # shared-core and reservation-only
424451
return None
425452

426-
region_spot = (instance.location[:-2], spot)
453+
region_capacity_type = (instance.location[:-2], capacity_type)
427454

428455
# For some instances, the price is proportional to the number of GPUs
429-
if instance.gpu_name and region_spot in self.gpu_slice[instance.gpu_name]:
430-
return instance.gpu_count * self.gpu_slice[instance.gpu_name][region_spot]
456+
if instance.gpu_name and region_capacity_type in self.gpu_slice[instance.gpu_name]:
457+
return instance.gpu_count * self.gpu_slice[instance.gpu_name][region_capacity_type]
431458

432459
# For others, the price consists of several components
433460
price = 0
434-
if region_spot not in self.cpu[vm_family]:
461+
if region_capacity_type not in self.cpu[vm_family]:
435462
return None
436-
price += instance.cpu * self.cpu[vm_family][region_spot]
437-
price += instance.memory * self.ram[vm_family][region_spot]
463+
price += instance.cpu * self.cpu[vm_family][region_capacity_type]
464+
price += instance.memory * self.ram[vm_family][region_capacity_type]
438465
if instance.gpu_name:
439-
if region_spot not in self.gpu[instance.gpu_name]:
466+
if region_capacity_type not in self.gpu[instance.gpu_name]:
440467
return None
441-
price += instance.gpu_count * self.gpu[instance.gpu_name][region_spot]
468+
price += instance.gpu_count * self.gpu[instance.gpu_name][region_capacity_type]
442469
if instance.instance_name in local_ssd_sizes_gib:
443-
price += local_ssd_sizes_gib[instance.instance_name] * self.local_ssd[region_spot]
470+
price += (
471+
local_ssd_sizes_gib[instance.instance_name] * self.local_ssd[region_capacity_type]
472+
)
444473

445474
return price
446475

@@ -454,6 +483,8 @@ def get_vm_family(instance_name: str) -> str:
454483

455484
def set_flags(catalog_items: list[RawCatalogItem]) -> None:
456485
for item in catalog_items:
486+
if cast(GCPCatalogItemProviderData, item.provider_data).get("is_dws_calendar_mode"):
487+
item.flags.append("gcp-dws-calendar-mode")
457488
if item.instance_name.startswith("a4-"):
458489
item.flags.append("gcp-a4")
459490
elif item.instance_name.startswith("g4-standard-") and item.price == 0:

src/tests/_internal/test_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def test_raw_catalog_item_to_from_dict() -> None:
151151
spot=False,
152152
disk_size=100.0,
153153
flags=["f1", "f2", "f3"],
154+
provider_data={"custom_prop": 42},
154155
)
155156
item_dict = item.dict()
156157
assert item_dict == {
@@ -167,5 +168,6 @@ def test_raw_catalog_item_to_from_dict() -> None:
167168
"spot": False,
168169
"disk_size": 100.0,
169170
"flags": "f1 f2 f3",
171+
"provider_data": '{"custom_prop": 42}',
170172
}
171173
assert RawCatalogItem.from_dict(item_dict) == item

0 commit comments

Comments
 (0)