Skip to content

Commit b0e159b

Browse files
[Feature]: Support dstack offer #2142 (#2540)
1 parent 4b4c55e commit b0e159b

9 files changed

Lines changed: 187 additions & 44 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"python-multipart>=0.0.16",
3333
"filelock",
3434
"psutil",
35-
"gpuhunt>=0.1.2,<0.2.0",
35+
"gpuhunt>=0.1.3,<0.2.0",
3636
"argcomplete>=3.5.0",
3737
]
3838

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import argparse
2+
import contextlib
3+
import json
4+
from pathlib import Path
5+
6+
from dstack._internal.cli.commands import APIBaseCommand
7+
from dstack._internal.cli.services.configurators.run import (
8+
BaseRunConfigurator,
9+
)
10+
from dstack._internal.cli.utils.common import console
11+
from dstack._internal.cli.utils.run import print_run_plan
12+
from dstack._internal.core.models.configurations import (
13+
ApplyConfigurationType,
14+
TaskConfiguration,
15+
)
16+
from dstack._internal.core.models.runs import RunSpec
17+
from dstack.api.utils import load_profile
18+
19+
20+
class OfferConfigurator(BaseRunConfigurator):
21+
# TODO: The command currently uses `BaseRunConfigurator` to register arguments.
22+
# This includes --env, --retry-policy, and other arguments that are unnecessary for this command.
23+
# Eventually, we should introduce a base `OfferConfigurator` that doesn't include those arguments—
24+
# `BaseRunConfigurator` will inherit from `OfferConfigurator`.
25+
#
26+
# Additionally, it should have its own type: `ApplyConfigurationType.OFFER`.
27+
TYPE = ApplyConfigurationType.TASK
28+
29+
@classmethod
30+
def register_args(
31+
cls,
32+
parser: argparse.ArgumentParser,
33+
):
34+
super().register_args(parser, default_max_offers=50)
35+
36+
37+
# TODO: Support aggregated offers
38+
# TODO: Add tests
39+
class OfferCommand(APIBaseCommand):
40+
NAME = "offer"
41+
DESCRIPTION = "List offers"
42+
43+
def _register(self):
44+
super()._register()
45+
self._parser.add_argument(
46+
"--format",
47+
choices=["plain", "json"],
48+
default="plain",
49+
help="Output format (default: plain)",
50+
)
51+
self._parser.add_argument(
52+
"--json",
53+
action="store_const",
54+
const="json",
55+
dest="format",
56+
help="Output in JSON format (equivalent to --format json)",
57+
)
58+
OfferConfigurator.register_args(self._parser)
59+
60+
def _command(self, args: argparse.Namespace):
61+
super()._command(args)
62+
conf = TaskConfiguration(commands=[":"])
63+
64+
configurator = OfferConfigurator(api_client=self.api)
65+
configurator.apply_args(conf, args, [])
66+
profile = load_profile(Path.cwd(), profile_name=args.profile)
67+
68+
run_spec = RunSpec(
69+
configuration=conf,
70+
ssh_key_pub="(dummy)",
71+
profile=profile,
72+
)
73+
if args.format == "plain":
74+
status = console.status("Getting offers...")
75+
else:
76+
status = contextlib.nullcontext()
77+
with status:
78+
run_plan = self.api.client.runs.get_plan(
79+
self.api.project,
80+
run_spec,
81+
max_offers=args.max_offers,
82+
)
83+
84+
job_plan = run_plan.job_plans[0]
85+
86+
if args.format == "json":
87+
output = {
88+
"project": run_plan.project_name,
89+
"user": run_plan.user,
90+
"resources": job_plan.job_spec.requirements.resources.dict(),
91+
"max_price": (job_plan.job_spec.requirements.max_price),
92+
"spot": run_spec.configuration.spot_policy,
93+
"reservation": run_plan.run_spec.configuration.reservation,
94+
"offers": [],
95+
"total_offers": job_plan.total_offers,
96+
}
97+
98+
for offer in job_plan.offers:
99+
output["offers"].append(
100+
{
101+
"backend": (
102+
"ssh" if offer.backend.value == "remote" else offer.backend.value
103+
),
104+
"region": offer.region,
105+
"instance_type": offer.instance.name,
106+
"resources": offer.instance.resources.dict(),
107+
"spot": offer.instance.resources.spot,
108+
"price": float(offer.price),
109+
"availability": offer.availability.value,
110+
}
111+
)
112+
113+
print(json.dumps(output, indent=2))
114+
return
115+
else:
116+
print_run_plan(run_plan, include_run_properties=False)

src/dstack/_internal/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from dstack._internal.cli.commands.init import InitCommand
1515
from dstack._internal.cli.commands.logs import LogsCommand
1616
from dstack._internal.cli.commands.metrics import MetricsCommand
17+
from dstack._internal.cli.commands.offer import OfferCommand
1718
from dstack._internal.cli.commands.ps import PsCommand
1819
from dstack._internal.cli.commands.server import ServerCommand
1920
from dstack._internal.cli.commands.stats import StatsCommand
@@ -65,6 +66,7 @@ def main():
6566
FleetCommand.register(subparsers)
6667
GatewayCommand.register(subparsers)
6768
InitCommand.register(subparsers)
69+
OfferCommand.register(subparsers)
6870
LogsCommand.register(subparsers)
6971
MetricsCommand.register(subparsers)
7072
PsCommand.register(subparsers)

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def apply_configuration(
9292
profile=profile,
9393
)
9494

95-
print_run_plan(run_plan, offers_limit=configurator_args.max_offers)
95+
print_run_plan(run_plan, max_offers=configurator_args.max_offers)
9696

9797
confirm_message = "Submit a new run?"
9898
stop_run_name = None
@@ -274,7 +274,7 @@ def delete_configuration(
274274
console.print(f"Run [code]{conf.name}[/] deleted")
275275

276276
@classmethod
277-
def register_args(cls, parser: argparse.ArgumentParser):
277+
def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int = 3):
278278
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
279279
configuration_group.add_argument(
280280
"-n",
@@ -286,7 +286,7 @@ def register_args(cls, parser: argparse.ArgumentParser):
286286
"--max-offers",
287287
help="Number of offers to show in the run plan",
288288
type=int,
289-
default=3,
289+
default=default_max_offers,
290290
)
291291
cls.register_env_args(configuration_group)
292292
configuration_group.add_argument(

src/dstack/_internal/cli/utils/run.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Union
1+
from typing import Any, Dict, List, Optional, Union
22

33
from rich.markup import escape
44
from rich.table import Table
@@ -24,7 +24,9 @@
2424
from dstack.api import Run
2525

2626

27-
def print_run_plan(run_plan: RunPlan, offers_limit: int = 3):
27+
def print_run_plan(
28+
run_plan: RunPlan, max_offers: Optional[int] = None, include_run_properties: bool = True
29+
):
2830
job_plan = run_plan.job_plans[0]
2931

3032
props = Table(box=None, show_header=False)
@@ -39,29 +41,30 @@ def print_run_plan(run_plan: RunPlan, offers_limit: int = 3):
3941
if job_plan.job_spec.max_duration
4042
else "-"
4143
)
42-
inactivity_duration = None
43-
if isinstance(run_plan.run_spec.configuration, DevEnvironmentConfiguration):
44-
inactivity_duration = "-"
45-
if isinstance(run_plan.run_spec.configuration.inactivity_duration, int):
46-
inactivity_duration = format_pretty_duration(
47-
run_plan.run_spec.configuration.inactivity_duration
48-
)
49-
if job_plan.job_spec.retry is None:
50-
retry = "-"
51-
else:
52-
retry = escape(job_plan.job_spec.retry.pretty_format())
53-
54-
profile = run_plan.run_spec.merged_profile
55-
creation_policy = profile.creation_policy
56-
# FIXME: This assumes the default idle_duration is the same for client and server.
57-
# If the server changes idle_duration, old clients will see incorrect value.
58-
termination_policy, termination_idle_time = get_termination(
59-
profile, DEFAULT_RUN_TERMINATION_IDLE_TIME
60-
)
61-
if termination_policy == TerminationPolicy.DONT_DESTROY:
62-
idle_duration = "-"
63-
else:
64-
idle_duration = format_pretty_duration(termination_idle_time)
44+
if include_run_properties:
45+
inactivity_duration = None
46+
if isinstance(run_plan.run_spec.configuration, DevEnvironmentConfiguration):
47+
inactivity_duration = "-"
48+
if isinstance(run_plan.run_spec.configuration.inactivity_duration, int):
49+
inactivity_duration = format_pretty_duration(
50+
run_plan.run_spec.configuration.inactivity_duration
51+
)
52+
if job_plan.job_spec.retry is None:
53+
retry = "-"
54+
else:
55+
retry = escape(job_plan.job_spec.retry.pretty_format())
56+
57+
profile = run_plan.run_spec.merged_profile
58+
creation_policy = profile.creation_policy
59+
# FIXME: This assumes the default idle_duration is the same for client and server.
60+
# If the server changes idle_duration, old clients will see incorrect value.
61+
termination_policy, termination_idle_time = get_termination(
62+
profile, DEFAULT_RUN_TERMINATION_IDLE_TIME
63+
)
64+
if termination_policy == TerminationPolicy.DONT_DESTROY:
65+
idle_duration = "-"
66+
else:
67+
idle_duration = format_pretty_duration(termination_idle_time)
6568

6669
if req.spot is None:
6770
spot_policy = "auto"
@@ -75,30 +78,32 @@ def th(s: str) -> str:
7578

7679
props.add_row(th("Project"), run_plan.project_name)
7780
props.add_row(th("User"), run_plan.user)
78-
props.add_row(th("Configuration"), run_plan.run_spec.configuration_path)
79-
props.add_row(th("Type"), run_plan.run_spec.configuration.type)
81+
if include_run_properties:
82+
props.add_row(th("Configuration"), run_plan.run_spec.configuration_path)
83+
props.add_row(th("Type"), run_plan.run_spec.configuration.type)
8084
props.add_row(th("Resources"), pretty_req)
81-
props.add_row(th("Max price"), max_price)
82-
props.add_row(th("Max duration"), max_duration)
83-
if inactivity_duration is not None: # None means n/a
84-
props.add_row(th("Inactivity duration"), inactivity_duration)
8585
props.add_row(th("Spot policy"), spot_policy)
86-
props.add_row(th("Retry policy"), retry)
87-
props.add_row(th("Creation policy"), creation_policy)
88-
props.add_row(th("Idle duration"), idle_duration)
86+
props.add_row(th("Max price"), max_price)
87+
if include_run_properties:
88+
props.add_row(th("Retry policy"), retry)
89+
props.add_row(th("Creation policy"), creation_policy)
90+
props.add_row(th("Idle duration"), idle_duration)
91+
props.add_row(th("Max duration"), max_duration)
92+
if inactivity_duration is not None: # None means n/a
93+
props.add_row(th("Inactivity duration"), inactivity_duration)
8994
props.add_row(th("Reservation"), run_plan.run_spec.configuration.reservation or "-")
9095

9196
offers = Table(box=None)
9297
offers.add_column("#")
9398
offers.add_column("BACKEND")
9499
offers.add_column("REGION")
95-
offers.add_column("INSTANCE")
100+
offers.add_column("INSTANCE TYPE")
96101
offers.add_column("RESOURCES")
97102
offers.add_column("SPOT")
98103
offers.add_column("PRICE")
99104
offers.add_column()
100105

101-
job_plan.offers = job_plan.offers[:offers_limit]
106+
job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers
102107

103108
for i, offer in enumerate(job_plan.offers, start=1):
104109
r = offer.instance.resources

src/dstack/_internal/server/routers/runs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ async def get_plan(
100100
project=project,
101101
user=user,
102102
run_spec=body.run_spec,
103+
max_offers=body.max_offers,
103104
)
104105
return run_plan
105106

src/dstack/_internal/server/schemas/runs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class GetRunRequest(CoreModel):
2626

2727
class GetRunPlanRequest(CoreModel):
2828
run_spec: RunSpec
29+
max_offers: Optional[int] = Field(
30+
description="The maximum number of offers to return", ge=1, le=10000
31+
)
2932

3033

3134
class SubmitRunRequest(CoreModel):

src/dstack/_internal/server/services/runs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY,
9393
}
9494

95+
DEFAULT_MAX_OFFERS = 50
96+
9597

9698
async def list_user_runs(
9799
session: AsyncSession,
@@ -275,6 +277,7 @@ async def get_plan(
275277
project: ProjectModel,
276278
user: UserModel,
277279
run_spec: RunSpec,
280+
max_offers: Optional[int],
278281
) -> RunPlan:
279282
_validate_run_spec_and_set_defaults(run_spec)
280283

@@ -342,7 +345,7 @@ async def get_plan(
342345

343346
job_plan = JobPlan(
344347
job_spec=job_spec,
345-
offers=job_offers[:50],
348+
offers=job_offers[: (max_offers or DEFAULT_MAX_OFFERS)],
346349
total_offers=len(job_offers),
347350
max_price=max((offer.price for offer in job_offers), default=None),
348351
)

src/dstack/api/server/_runs.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,13 @@ def get(self, project_name: str, run_name: str) -> Run:
5353
resp = self._request(f"/api/project/{project_name}/runs/get", body=json_body)
5454
return parse_obj_as(Run.__response__, resp.json())
5555

56-
def get_plan(self, project_name: str, run_spec: RunSpec) -> RunPlan:
57-
body = GetRunPlanRequest(run_spec=run_spec)
56+
def get_plan(
57+
self, project_name: str, run_spec: RunSpec, max_offers: Optional[int] = None
58+
) -> RunPlan:
59+
body = GetRunPlanRequest(run_spec=run_spec, max_offers=max_offers)
5860
resp = self._request(
5961
f"/api/project/{project_name}/runs/get_plan",
60-
body=body.json(exclude=_get_run_spec_excludes(run_spec)),
62+
body=body.json(exclude=_get_get_plan_excludes(body)),
6163
)
6264
return parse_obj_as(RunPlan.__response__, resp.json())
6365

@@ -96,6 +98,17 @@ def _get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
9698
return None
9799

98100

101+
def _get_get_plan_excludes(request: GetRunPlanRequest) -> Optional[Dict]:
102+
"""
103+
Excludes new fields when they are not set to keep
104+
clients backward-compatibility with older servers.
105+
"""
106+
run_spec_excludes = _get_run_spec_excludes(request.run_spec)
107+
if request.max_offers is None:
108+
run_spec_excludes["max_offers"] = True
109+
return run_spec_excludes
110+
111+
99112
def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
100113
"""
101114
Returns `run_spec` exclude mapping to exclude certain fields from the request.

0 commit comments

Comments
 (0)