Skip to content

Commit 2a0d7f5

Browse files
authored
Allow configuring Vast.ai offer order (#234)
1 parent 69496b8 commit 2a0d7f5

2 files changed

Lines changed: 9 additions & 11 deletions

File tree

src/gpuhunt/providers/vastai.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import re
44
from collections import defaultdict
5+
from collections.abc import Iterable
56
from typing import Any, Literal
67

78
import requests
@@ -25,16 +26,16 @@ def __init__(
2526
self,
2627
extra_filters: dict[str, dict[Operators, FilterValue]] | None = None,
2728
community_cloud: bool = True,
29+
order: Iterable[tuple[str, str]] = [("score", "desc")],
2830
):
2931
self.extra_filters = extra_filters
3032
self.community_cloud = community_cloud
33+
self.order = list(order)
3134

3235
def get(
3336
self, query_filter: QueryFilter | None = None, balance_resources: bool = True
3437
) -> list[RawCatalogItem]:
35-
filters: dict[str, Any] = self.make_filters(
36-
query_filter or QueryFilter(), community_cloud=self.community_cloud
37-
)
38+
filters: dict[str, Any] = self.make_filters(query_filter or QueryFilter())
3839
if self.extra_filters:
3940
for key, constraints in self.extra_filters.items():
4041
for op, value in constraints.items():
@@ -85,10 +86,7 @@ def get(
8586
instance_offers.append(spot_offer)
8687
return instance_offers
8788

88-
@staticmethod
89-
def make_filters(
90-
q: QueryFilter, community_cloud: bool = True
91-
) -> dict[str, dict[Operators, FilterValue]]:
89+
def make_filters(self, q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]:
9290
filters = defaultdict(dict)
9391
if q.min_cpu is not None:
9492
filters["cpu_cores"]["gte"] = q.min_cpu
@@ -132,11 +130,11 @@ def make_filters(
132130
# Datacenter offers map to Vast's "server cloud" scope.
133131
# When community_cloud is enabled, keep scope unfiltered so both
134132
# server and community offers are returned.
135-
if not community_cloud:
133+
if not self.community_cloud:
136134
filters["datacenter"]["eq"] = True
137135
filters["rentable"]["eq"] = True
138136
filters["rented"]["eq"] = False
139-
filters["order"] = [["score", "desc"]]
137+
filters["order"] = self.order
140138
return filters
141139

142140
@staticmethod

src/tests/providers/test_vastai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44

55
def test_make_filters_defaults_to_datacenter_only():
6-
filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=False)
6+
filters = VastAIProvider(community_cloud=False).make_filters(QueryFilter())
77
assert filters["datacenter"]["eq"] is True
88
assert "external" not in filters
99

1010

1111
def test_make_filters_does_not_constrain_scope_when_community_cloud_enabled():
12-
filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=True)
12+
filters = VastAIProvider(community_cloud=True).make_filters(QueryFilter())
1313
assert "datacenter" not in filters
1414
assert "external" not in filters

0 commit comments

Comments
 (0)