Skip to content

Commit 305a362

Browse files
[Feature]: Allow listing available key resources such as gpu, region, and backends #2142
Merged `dstack gpu` into `dstack offer`
1 parent fdf6da6 commit 305a362

4 files changed

Lines changed: 314 additions & 63 deletions

File tree

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,20 @@
11
import argparse
2-
import contextlib
3-
import json
42
from pathlib import Path
3+
from typing import List
54

65
from dstack._internal.cli.commands import APIBaseCommand
7-
from dstack._internal.cli.services.configurators.run import (
8-
BaseRunConfigurator,
9-
)
6+
from dstack._internal.cli.services.configurators.run import BaseRunConfigurator
107
from dstack._internal.cli.utils.common import console
11-
from dstack._internal.cli.utils.run import print_run_plan
12-
from dstack._internal.core.models.configurations import (
13-
ApplyConfigurationType,
14-
TaskConfiguration,
15-
)
8+
from dstack._internal.cli.utils.gpu import print_gpu_json, print_gpu_table
9+
from dstack._internal.cli.utils.run import print_offers_json, print_run_plan
10+
from dstack._internal.core.errors import CLIError
11+
from dstack._internal.core.models.configurations import ApplyConfigurationType, TaskConfiguration
1612
from dstack._internal.core.models.runs import RunSpec
13+
from dstack._internal.server.schemas.gpus import GpuGroup
1714
from dstack.api.utils import load_profile
1815

1916

2017
class OfferConfigurator(BaseRunConfigurator):
21-
# TODO: The command currently uses `BaseRunConfigurator` to register arguments.
22-
# This includes --env, --retry-policy, and other arguments that are unnecessary for this command.
23-
# Eventually, we should introduce a base `OfferConfigurator` that doesn't include those arguments—
24-
# `BaseRunConfigurator` will inherit from `OfferConfigurator`.
25-
#
26-
# Additionally, it should have its own type: `ApplyConfigurationType.OFFER`.
2718
TYPE = ApplyConfigurationType.TASK
2819

2920
@classmethod
@@ -32,10 +23,18 @@ def register_args(
3223
parser: argparse.ArgumentParser,
3324
):
3425
super().register_args(parser, default_max_offers=50)
26+
parser.add_argument(
27+
"--group-by",
28+
action="append",
29+
help=(
30+
"Group results by fields ([code]gpu[/code], [code]backend[/code], [code]region[/code], [code]count[/code]). "
31+
"Optional, but if used, must include [code]gpu[/code]. "
32+
"The use of [code]region[/code] also requires [code]backend[/code]. "
33+
"Can be repeated or comma-separated (e.g. [code]--group-by gpu,backend[/code])."
34+
),
35+
)
3536

3637

37-
# TODO: Support aggregated offers
38-
# TODO: Add tests
3938
class OfferCommand(APIBaseCommand):
4039
NAME = "offer"
4140
DESCRIPTION = "List offers"
@@ -70,49 +69,58 @@ def _command(self, args: argparse.Namespace):
7069
ssh_key_pub="(dummy)",
7170
profile=profile,
7271
)
72+
73+
if args.group_by:
74+
args.group_by = self._process_group_by_args(args.group_by)
75+
76+
if args.group_by and "gpu" not in args.group_by:
77+
group_values = ", ".join(args.group_by)
78+
raise CLIError(f"Cannot group by '{group_values}' without also grouping by 'gpu'")
79+
7380
if args.format == "plain":
74-
status = console.status("Getting offers...")
81+
with console.status("Getting offers..."):
82+
if args.group_by:
83+
gpus = self._list_gpus(args, run_spec)
84+
print_gpu_table(gpus, run_spec, args.group_by, self.api.project)
85+
else:
86+
run_plan = self.api.client.runs.get_plan(
87+
self.api.project,
88+
run_spec,
89+
max_offers=args.max_offers,
90+
)
91+
print_run_plan(run_plan, include_run_properties=False)
7592
else:
76-
status = contextlib.nullcontext()
77-
with status:
78-
run_plan = self.api.client.runs.get_plan(
79-
self.api.project,
80-
run_spec,
81-
max_offers=args.max_offers,
82-
)
83-
84-
job_plan = run_plan.job_plans[0]
85-
86-
if args.format == "json":
87-
# FIXME: Should use effective_run_spec from run_plan,
88-
# since the spec can be changed by the server and plugins
89-
output = {
90-
"project": run_plan.project_name,
91-
"user": run_plan.user,
92-
"resources": job_plan.job_spec.requirements.resources.dict(),
93-
"max_price": (job_plan.job_spec.requirements.max_price),
94-
"spot": run_spec.configuration.spot_policy,
95-
"reservation": run_plan.run_spec.configuration.reservation,
96-
"offers": [],
97-
"total_offers": job_plan.total_offers,
98-
}
99-
100-
for offer in job_plan.offers:
101-
output["offers"].append(
102-
{
103-
"backend": (
104-
"ssh" if offer.backend.value == "remote" else offer.backend.value
105-
),
106-
"region": offer.region,
107-
"instance_type": offer.instance.name,
108-
"resources": offer.instance.resources.dict(),
109-
"spot": offer.instance.resources.spot,
110-
"price": float(offer.price),
111-
"availability": offer.availability.value,
112-
}
93+
if args.group_by:
94+
gpus = self._list_gpus(args, run_spec)
95+
print_gpu_json(gpus, run_spec, args.group_by, self.api.project)
96+
else:
97+
run_plan = self.api.client.runs.get_plan(
98+
self.api.project,
99+
run_spec,
100+
max_offers=args.max_offers,
113101
)
102+
print_offers_json(run_plan, run_spec)
114103

115-
print(json.dumps(output, indent=2))
116-
return
117-
else:
118-
print_run_plan(run_plan, include_run_properties=False)
104+
def _process_group_by_args(self, group_by_args: List[str]) -> List[str]:
105+
valid_choices = {"gpu", "backend", "region", "count"}
106+
processed = []
107+
108+
for arg in group_by_args:
109+
values = [v.strip() for v in arg.split(",") if v.strip()]
110+
for value in values:
111+
if value in valid_choices:
112+
processed.append(value)
113+
else:
114+
raise CLIError(
115+
f"Invalid group-by value: '{value}'. Valid choices are: {', '.join(sorted(valid_choices))}"
116+
)
117+
118+
return processed
119+
120+
def _list_gpus(self, args: List[str], run_spec: RunSpec) -> List[GpuGroup]:
121+
group_by = [g for g in args.group_by if g != "gpu"] or None
122+
return self.api.client.gpus.list_gpus(
123+
self.api.project,
124+
run_spec,
125+
group_by=group_by,
126+
)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import shutil
2+
from typing import List
3+
4+
from rich.table import Table
5+
6+
from dstack._internal.cli.utils.common import console
7+
from dstack._internal.core.models.profiles import SpotPolicy
8+
from dstack._internal.core.models.runs import Requirements, RunSpec, get_policy_map
9+
from dstack._internal.server.schemas.gpus import GpuGroup
10+
11+
12+
def print_gpu_json(gpu_response, run_spec, group_by_cli, api_project):
13+
"""Print GPU information in JSON format."""
14+
req = Requirements(
15+
resources=run_spec.configuration.resources,
16+
max_price=run_spec.merged_profile.max_price,
17+
spot=get_policy_map(run_spec.merged_profile.spot_policy, default=SpotPolicy.AUTO),
18+
reservation=run_spec.configuration.reservation,
19+
)
20+
21+
if req.spot is None:
22+
spot_policy = "auto"
23+
elif req.spot:
24+
spot_policy = "spot"
25+
else:
26+
spot_policy = "on-demand"
27+
28+
output = {
29+
"project": api_project,
30+
"user": "admin", # TODO: Get actual user name
31+
"resources": req.resources.dict(),
32+
"spot_policy": spot_policy,
33+
"max_price": req.max_price,
34+
"reservation": run_spec.configuration.reservation,
35+
"group_by": group_by_cli,
36+
"gpus": [],
37+
}
38+
39+
for gpu_group in gpu_response.gpus:
40+
gpu_data = {
41+
"name": gpu_group.name,
42+
"memory_mib": gpu_group.memory_mib,
43+
"vendor": gpu_group.vendor.value,
44+
"availability": [av.value for av in gpu_group.availability],
45+
"spot": gpu_group.spot,
46+
"count": {"min": gpu_group.count.min, "max": gpu_group.count.max},
47+
"price": {"min": gpu_group.price.min, "max": gpu_group.price.max},
48+
}
49+
50+
if gpu_group.backend:
51+
gpu_data["backend"] = gpu_group.backend.value
52+
if gpu_group.backends:
53+
gpu_data["backends"] = [b.value for b in gpu_group.backends]
54+
if gpu_group.region:
55+
gpu_data["region"] = gpu_group.region
56+
if gpu_group.regions:
57+
gpu_data["regions"] = gpu_group.regions
58+
59+
output["gpus"].append(gpu_data)
60+
61+
import json
62+
63+
print(json.dumps(output, indent=2))
64+
65+
66+
def print_gpu_table(gpus: List[GpuGroup], run_spec: RunSpec, group_by: List[str], project: str):
67+
"""Print GPU information in a formatted table."""
68+
print_filter_info(run_spec, group_by, project)
69+
70+
has_single_backend = any(gpu_group.backend for gpu_group in gpus)
71+
has_single_region = any(gpu_group.region for gpu_group in gpus)
72+
has_multiple_regions = any(gpu_group.regions for gpu_group in gpus)
73+
74+
if has_single_backend and has_single_region:
75+
backend_column = "BACKEND"
76+
region_column = "REGION"
77+
elif has_single_backend and has_multiple_regions:
78+
backend_column = "BACKEND"
79+
region_column = "REGIONS"
80+
else:
81+
backend_column = "BACKENDS"
82+
region_column = None
83+
84+
table = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
85+
table.add_column("#")
86+
table.add_column("GPU", no_wrap=True, ratio=2)
87+
table.add_column("SPOT", style="grey58", ratio=1)
88+
table.add_column("$/GPU", style="grey58", ratio=1)
89+
table.add_column(backend_column, style="grey58", ratio=2)
90+
if region_column:
91+
table.add_column(region_column, style="grey58", ratio=2)
92+
table.add_column()
93+
94+
for i, gpu_group in enumerate(gpus, start=1):
95+
backend_text = ""
96+
if gpu_group.backend:
97+
backend_text = gpu_group.backend.value
98+
elif gpu_group.backends:
99+
backend_text = ", ".join(b.value for b in gpu_group.backends)
100+
101+
region_text = ""
102+
if gpu_group.region:
103+
region_text = gpu_group.region
104+
elif gpu_group.regions:
105+
if len(gpu_group.regions) <= 3:
106+
region_text = ", ".join(gpu_group.regions)
107+
else:
108+
region_text = f"{len(gpu_group.regions)} regions"
109+
110+
if not region_column:
111+
if gpu_group.regions and len(gpu_group.regions) > 3:
112+
shortened_region_text = f"{len(gpu_group.regions)} regions"
113+
backends_display = (
114+
f"{backend_text} ({shortened_region_text})"
115+
if shortened_region_text
116+
else backend_text
117+
)
118+
else:
119+
backends_display = (
120+
f"{backend_text} ({region_text})" if region_text else backend_text
121+
)
122+
else:
123+
backends_display = backend_text
124+
125+
memory_gb = f"{gpu_group.memory_mib // 1024}GB"
126+
if gpu_group.count.min == gpu_group.count.max:
127+
count_range = str(gpu_group.count.min)
128+
else:
129+
count_range = f"{gpu_group.count.min}..{gpu_group.count.max}"
130+
131+
gpu_spec = f"{gpu_group.name}:{memory_gb}:{count_range}"
132+
133+
spot_types = []
134+
if "spot" in gpu_group.spot:
135+
spot_types.append("spot")
136+
if "on-demand" in gpu_group.spot:
137+
spot_types.append("on-demand")
138+
spot_display = ", ".join(spot_types)
139+
140+
if gpu_group.price.min == gpu_group.price.max:
141+
price_display = f"{gpu_group.price.min:.4f}".rstrip("0").rstrip(".")
142+
else:
143+
min_formatted = f"{gpu_group.price.min:.4f}".rstrip("0").rstrip(".")
144+
max_formatted = f"{gpu_group.price.max:.4f}".rstrip("0").rstrip(".")
145+
price_display = f"{min_formatted}..{max_formatted}"
146+
147+
availability = ""
148+
has_available = any(av.is_available() for av in gpu_group.availability)
149+
has_unavailable = any(not av.is_available() for av in gpu_group.availability)
150+
151+
if has_unavailable and not has_available:
152+
for av in gpu_group.availability:
153+
if av.value in {"not_available", "no_quota", "idle", "busy"}:
154+
availability = av.value.replace("_", " ").lower()
155+
break
156+
157+
secondary_style = "grey58"
158+
row_data = [
159+
f"[{secondary_style}]{i}[/]",
160+
gpu_spec,
161+
f"[{secondary_style}]{spot_display}[/]",
162+
f"[{secondary_style}]{price_display}[/]",
163+
f"[{secondary_style}]{backends_display}[/]",
164+
]
165+
if region_column:
166+
row_data.append(f"[{secondary_style}]{region_text}[/]")
167+
row_data.append(f"[{secondary_style}]{availability}[/]")
168+
169+
table.add_row(*row_data)
170+
171+
console.print(table)
172+
173+
174+
def print_filter_info(run_spec: RunSpec, group_by: List[str], project: str):
175+
"""Print filter information for GPU display."""
176+
props = Table(box=None, show_header=False)
177+
props.add_column(no_wrap=True)
178+
props.add_column()
179+
180+
req = Requirements(
181+
resources=run_spec.configuration.resources,
182+
max_price=run_spec.merged_profile.max_price,
183+
spot=get_policy_map(run_spec.merged_profile.spot_policy, default=SpotPolicy.AUTO),
184+
reservation=run_spec.merged_profile.reservation,
185+
)
186+
187+
pretty_req = req.pretty_format(resources_only=True)
188+
max_price = f"${req.max_price:.3f}".rstrip("0").rstrip(".") if req.max_price else "-"
189+
190+
if req.spot is None:
191+
spot_policy = "auto"
192+
elif req.spot:
193+
spot_policy = "spot"
194+
else:
195+
spot_policy = "on-demand"
196+
197+
def th(s: str) -> str:
198+
return f"[bold]{s}[/bold]"
199+
200+
props.add_row(th("Project"), project)
201+
props.add_row(th("User"), "admin") # TODO: Get actual user name
202+
props.add_row(th("Resources"), pretty_req)
203+
props.add_row(th("Spot policy"), spot_policy)
204+
props.add_row(th("Max price"), max_price)
205+
props.add_row(th("Reservation"), run_spec.configuration.reservation or "-")
206+
if group_by:
207+
props.add_row(th("Group by"), ", ".join(group_by))
208+
209+
console.print(props)
210+
console.print()

0 commit comments

Comments
 (0)