Skip to content

Commit 3d8c9ba

Browse files
Set explicit GPU defaults in ResourcesSpec and improve default GPU vendor selection (#3573)
* Set explicit GPU default (`0..`) in `ResourcesSpec` and make minor improvements in resource pretty-printing.
* Change how the GPU vendor default is set to make it more explicit:
  - Default to NVIDIA only if the user has no custom image.
  - Keep backward compatibility between old/new server and CLI versions.
  - Make `dstack offer` consistent with `dstack apply`.
1 parent 5cc60b7 commit 3d8c9ba

File tree

11 files changed

+219
-37
lines changed

11 files changed

+219
-37
lines changed

src/dstack/_internal/cli/commands/offer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def _register(self):
7474

7575
def _command(self, args: argparse.Namespace):
7676
super()._command(args)
77-
conf = TaskConfiguration(commands=[":"])
77+
# Set image and user so that the server (a) does not default gpu.vendor
78+
# to nvidia — `dstack offer` should show all vendors, and (b) does not
79+
# attempt to pull image config from the Docker registry.
80+
conf = TaskConfiguration(commands=[":"], image="scratch", user="root")
7881

7982
configurator = OfferConfigurator(api_client=self.api)
8083
configurator.apply_args(conf, args)

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,13 @@ def interpolate_env(self, conf: RunConfigurationT):
383383

384384
def validate_gpu_vendor_and_image(self, conf: RunConfigurationT) -> None:
385385
"""
386-
Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
386+
Infers GPU vendor if not set. Defaults to Nvidia when using the default
387+
CUDA image. Requires explicit `image` if the vendor is AMD or Tenstorrent.
388+
389+
NOTE: We don't set the inferred vendor on gpu_spec for compatibility with
390+
older servers. Servers set the vendor using the same logic in
391+
set_resources_defaults(). The inferred vendor is used here only for
392+
validation and display (see _infer_gpu_vendor).
387393
"""
388394
gpu_spec = conf.resources.gpu
389395
if gpu_spec is None:
@@ -425,12 +431,18 @@ def validate_gpu_vendor_and_image(self, conf: RunConfigurationT) -> None:
425431
# CUDA image, not a big deal.
426432
has_amd_gpu = gpuhunt.AcceleratorVendor.AMD in vendors
427433
has_tt_gpu = gpuhunt.AcceleratorVendor.TENSTORRENT in vendors
434+
# Set vendor inferred from name on the spec (server needs it for filtering).
435+
gpu_spec.vendor = vendor
428436
else:
429-
# If neither gpu.vendor nor gpu.name is set, assume Nvidia.
430-
vendor = gpuhunt.AcceleratorVendor.NVIDIA
437+
# No vendor or name specified. Default to Nvidia if using the default
438+
# CUDA image, since it's only compatible with Nvidia GPUs.
439+
# We don't set the inferred vendor on the spec — the server does the
440+
# same inference in set_resources_defaults() for compatibility with
441+
# older servers that don't handle vendor + count.min=0 correctly.
442+
if conf.image is None and conf.docker is not True:
443+
vendor = gpuhunt.AcceleratorVendor.NVIDIA
431444
has_amd_gpu = False
432445
has_tt_gpu = False
433-
gpu_spec.vendor = vendor
434446
else:
435447
has_amd_gpu = vendor == gpuhunt.AcceleratorVendor.AMD
436448
has_tt_gpu = vendor == gpuhunt.AcceleratorVendor.TENSTORRENT

src/dstack/_internal/core/models/resources.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor:
319319
return gpuhunt.AcceleratorVendor.cast(v)
320320

321321

322+
DEFAULT_GPU_SPEC = GPUSpec(count=Range[int](min=0, max=None))
323+
324+
322325
class DiskSpecConfig(CoreConfig):
323326
@staticmethod
324327
def schema_extra(schema: Dict[str, Any]):
@@ -387,7 +390,8 @@ class ResourcesSpec(generate_dual_core_model(ResourcesSpecConfig)):
387390
"you may need to configure this"
388391
),
389392
] = None
390-
gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
393+
# Optional for backward compatibility
394+
gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = DEFAULT_GPU_SPEC
391395
disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK
392396

393397
def pretty_format(self) -> str:
@@ -397,6 +401,7 @@ def pretty_format(self) -> str:
397401
if self.gpu:
398402
gpu = self.gpu
399403
resources.update(
404+
gpu_vendor=gpu.vendor,
400405
gpu_name=",".join(gpu.name) if gpu.name else None,
401406
gpu_count=gpu.count,
402407
gpu_memory=gpu.memory,

src/dstack/_internal/server/services/resources.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import gpuhunt
24
from pydantic import parse_obj_as
35

@@ -19,3 +21,24 @@ def set_resources_defaults(resources: ResourcesSpec) -> None:
1921
else:
2022
cpu.arch = gpuhunt.CPUArchitecture.X86
2123
resources.cpu = cpu
24+
25+
26+
def set_gpu_vendor_default(
27+
resources: ResourcesSpec,
28+
image: Optional[str],
29+
docker: Optional[bool],
30+
) -> None:
31+
"""Default GPU vendor to Nvidia when using the default CUDA image,
32+
since it's only compatible with Nvidia GPUs.
33+
Mirrors the client-side logic in validate_gpu_vendor_and_image().
34+
Should only be called for runs (not fleets) since fleets don't have image context."""
35+
gpu = resources.gpu
36+
if (
37+
gpu is not None
38+
and gpu.vendor is None
39+
and gpu.name is None
40+
and gpu.count.max != 0
41+
and image is None
42+
and docker is not True
43+
):
44+
gpu.vendor = gpuhunt.AcceleratorVendor.NVIDIA

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@
6565
from dstack._internal.server.services.plugins import apply_plugin_policies
6666
from dstack._internal.server.services.probes import is_probe_ready
6767
from dstack._internal.server.services.projects import list_user_project_models
68-
from dstack._internal.server.services.resources import set_resources_defaults
68+
from dstack._internal.server.services.resources import (
69+
set_gpu_vendor_default,
70+
set_resources_defaults,
71+
)
6972
from dstack._internal.server.services.runs.plan import get_job_plans
7073
from dstack._internal.server.services.runs.spec import (
7174
can_update_run_spec,
@@ -343,8 +346,8 @@ async def get_plan(
343346
)
344347
if current_resource is not None:
345348
# For backward compatibility (current_resource may have been submitted before
346-
# some fields, e.g., CPUSpec.arch, were added)
347-
set_resources_defaults(current_resource.run_spec.configuration.resources)
349+
# some fields, e.g., CPUSpec.arch, gpu.vendor were added)
350+
_set_run_resources_defaults(current_resource.run_spec)
348351
if not current_resource.status.is_finished() and can_update_run_spec(
349352
current_resource.run_spec, effective_run_spec
350353
):
@@ -354,7 +357,7 @@ async def get_plan(
354357
session=session,
355358
project=project,
356359
profile=profile,
357-
run_spec=run_spec,
360+
run_spec=effective_run_spec,
358361
max_offers=max_offers,
359362
)
360363
run_plan = RunPlan(
@@ -410,8 +413,8 @@ async def apply_plan(
410413
current_resource = run_model_to_run(current_resource_model, return_in_api=True)
411414

412415
# For backward compatibility (current_resource may have been submitted before
413-
# some fields, e.g., CPUSpec.arch, were added)
414-
set_resources_defaults(current_resource.run_spec.configuration.resources)
416+
# some fields, e.g., CPUSpec.arch, gpu.vendor were added)
417+
_set_run_resources_defaults(current_resource.run_spec)
415418
try:
416419
spec_diff = check_can_update_run_spec(current_resource.run_spec, run_spec)
417420
except ServerClientError:
@@ -421,7 +424,7 @@ async def apply_plan(
421424
raise
422425
if not force:
423426
if plan.current_resource is not None:
424-
set_resources_defaults(plan.current_resource.run_spec.configuration.resources)
427+
_set_run_resources_defaults(plan.current_resource.run_spec)
425428
if (
426429
plan.current_resource is None
427430
or plan.current_resource.id != current_resource.id
@@ -782,6 +785,16 @@ def run_model_to_run(
782785
return run
783786

784787

788+
def _set_run_resources_defaults(run_spec: RunSpec) -> None:
789+
"""Apply resource defaults to a run spec, including GPU vendor inference."""
790+
set_resources_defaults(run_spec.configuration.resources)
791+
set_gpu_vendor_default(
792+
run_spec.configuration.resources,
793+
image=run_spec.configuration.image,
794+
docker=getattr(run_spec.configuration, "docker", None),
795+
)
796+
797+
785798
def _get_run_jobs_with_submissions(
786799
run_model: RunModel,
787800
job_submissions_limit: Optional[int],

src/dstack/_internal/server/services/runs/spec.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from dstack._internal.server import settings
99
from dstack._internal.server.models import UserModel
1010
from dstack._internal.server.services.docker import is_valid_docker_volume_target
11-
from dstack._internal.server.services.resources import set_resources_defaults
11+
from dstack._internal.server.services.resources import (
12+
set_gpu_vendor_default,
13+
set_resources_defaults,
14+
)
1215
from dstack._internal.utils.logging import get_logger
1316

1417
logger = get_logger(__name__)
@@ -108,6 +111,11 @@ def validate_run_spec_and_set_defaults(
108111
if run_spec.configuration.priority is None:
109112
run_spec.configuration.priority = RUN_PRIORITY_DEFAULT
110113
set_resources_defaults(run_spec.configuration.resources)
114+
set_gpu_vendor_default(
115+
run_spec.configuration.resources,
116+
image=run_spec.configuration.image,
117+
docker=getattr(run_spec.configuration, "docker", None),
118+
)
111119
if run_spec.ssh_key_pub is None:
112120
if user.ssh_public_key:
113121
run_spec.ssh_key_pub = user.ssh_public_key

src/dstack/_internal/utils/common.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,14 @@ def pretty_resources(
9191
cpus: Optional[Any] = None,
9292
memory: Optional[Any] = None,
9393
gpu_count: Optional[Any] = None,
94+
gpu_vendor: Optional[Any] = None,
9495
gpu_name: Optional[Any] = None,
9596
gpu_memory: Optional[Any] = None,
9697
total_gpu_memory: Optional[Any] = None,
9798
compute_capability: Optional[Any] = None,
9899
disk_size: Optional[Any] = None,
99100
) -> str:
100-
"""
101-
>>> pretty_resources(cpus=4, memory="16GB")
102-
'4xCPU, 16GB'
103-
>>> pretty_resources(cpus=4, memory="16GB", gpu_count=1)
104-
'4xCPU, 16GB, 1xGPU'
105-
>>> pretty_resources(cpus=4, memory="16GB", gpu_count=1, gpu_name='A100')
106-
'4xCPU, 16GB, 1xA100'
107-
>>> pretty_resources(cpus=4, memory="16GB", gpu_count=1, gpu_name='A100', gpu_memory="40GB")
108-
'4xCPU, 16GB, 1xA100 (40GB)'
109-
>>> pretty_resources(cpus=4, memory="16GB", gpu_count=1, total_gpu_memory="80GB")
110-
'4xCPU, 16GB, 1xGPU (total 80GB)'
111-
>>> pretty_resources(cpus=4, memory="16GB", gpu_count=2, gpu_name='A100', gpu_memory="40GB", total_gpu_memory="80GB")
112-
'4xCPU, 16GB, 2xA100 (40GB, total 80GB)'
113-
>>> pretty_resources(gpu_count=1, compute_capability="8.0")
114-
'1xGPU (8.0)'
115-
"""
101+
"""Format resource requirements as a human-readable string."""
116102
parts = []
117103
if cpus is not None:
118104
cpu_arch_lower: Optional[str] = None
@@ -131,7 +117,6 @@ def pretty_resources(
131117
parts.append(f"disk={disk_size}")
132118
if gpu_count:
133119
gpu_parts = []
134-
gpu_parts.append(f"{gpu_name or 'gpu'}")
135120
if gpu_memory is not None:
136121
gpu_parts.append(f"{gpu_memory}")
137122
if gpu_count is not None:
@@ -141,8 +126,13 @@ def pretty_resources(
141126
if compute_capability is not None:
142127
gpu_parts.append(f"{compute_capability}")
143128

144-
gpu = ":".join(gpu_parts)
145-
parts.append(gpu)
129+
if gpu_name:
130+
parts.append("gpu=" + ":".join([f"{gpu_name}"] + gpu_parts))
131+
elif gpu_vendor:
132+
vendor_str = gpu_vendor.value if isinstance(gpu_vendor, enum.Enum) else str(gpu_vendor)
133+
parts.append("gpu=" + ":".join([vendor_str] + gpu_parts))
134+
else:
135+
parts.append("gpu=" + ":".join(gpu_parts))
146136
return " ".join(parts)
147137

148138

src/tests/_internal/cli/services/configurators/test_run.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,34 @@ def validate(self, conf: BaseRunConfiguration) -> None:
132132
def test_no_gpu(self):
133133
conf = self.prepare_conf()
134134
self.validate(conf)
135-
assert conf.resources.gpu is None
135+
assert conf.resources.gpu is not None
136+
# Vendor is not written to spec for compatibility with older servers.
137+
# The server infers nvidia in set_resources_defaults().
138+
assert conf.resources.gpu.vendor is None
139+
assert conf.resources.gpu.name is None
140+
assert conf.resources.gpu.count.min == 0
136141

137142
def test_zero_gpu(self):
138143
conf = self.prepare_conf(gpu_spec="0")
139144
self.validate(conf)
140145
assert conf.resources.gpu.vendor is None
141146

147+
def test_gpu_no_vendor_no_image_defaults_to_nvidia(self):
148+
"""Vendor is inferred as nvidia for validation but NOT written to spec."""
149+
conf = self.prepare_conf(gpu_spec="1")
150+
self.validate(conf)
151+
assert conf.resources.gpu.vendor is None
152+
153+
def test_gpu_no_vendor_with_image_no_default(self):
154+
conf = self.prepare_conf(gpu_spec="1", image="my-custom-image")
155+
self.validate(conf)
156+
assert conf.resources.gpu.vendor is None
157+
158+
def test_gpu_no_vendor_docker_true_no_default(self):
159+
conf = self.prepare_conf(gpu_spec="1", docker=True)
160+
self.validate(conf)
161+
assert conf.resources.gpu.vendor is None
162+
142163
@pytest.mark.parametrize(
143164
["gpu_spec", "expected_vendor"],
144165
[

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,14 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
344344
"cpu": {"min": 2, "max": None},
345345
"memory": {"min": 8.0, "max": None},
346346
"shm_size": None,
347-
"gpu": None,
347+
"gpu": {
348+
"vendor": None,
349+
"name": None,
350+
"count": {"min": 0, "max": None},
351+
"memory": None,
352+
"total_memory": None,
353+
"compute_capability": None,
354+
},
348355
"disk": {"size": {"min": 100.0, "max": None}},
349356
},
350357
"backends": None,
@@ -467,7 +474,14 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
467474
"cpu": {"min": 2, "max": None},
468475
"memory": {"min": 8.0, "max": None},
469476
"shm_size": None,
470-
"gpu": None,
477+
"gpu": {
478+
"vendor": None,
479+
"name": None,
480+
"count": {"min": 0, "max": None},
481+
"memory": None,
482+
"total_memory": None,
483+
"compute_capability": None,
484+
},
471485
"disk": {"size": {"min": 100.0, "max": None}},
472486
},
473487
"backends": None,
@@ -639,7 +653,14 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A
639653
"cpu": {"min": 2, "max": None},
640654
"memory": {"min": 8.0, "max": None},
641655
"shm_size": None,
642-
"gpu": None,
656+
"gpu": {
657+
"vendor": None,
658+
"name": None,
659+
"count": {"min": 0, "max": None},
660+
"memory": None,
661+
"total_memory": None,
662+
"compute_capability": None,
663+
},
643664
"disk": {"size": {"min": 100.0, "max": None}},
644665
},
645666
"backends": None,

src/tests/_internal/server/routers/test_runs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
from dstack._internal.server.models import JobModel, RunModel
4848
from dstack._internal.server.schemas.runs import ApplyRunPlanRequest
4949
from dstack._internal.server.services.projects import add_project_member
50+
from dstack._internal.server.services.resources import (
51+
set_gpu_vendor_default,
52+
set_resources_defaults,
53+
)
5054
from dstack._internal.server.services.runs import run_model_to_run
5155
from dstack._internal.server.services.runs.spec import validate_run_spec_and_set_defaults
5256
from dstack._internal.server.testing.common import (
@@ -1535,6 +1539,13 @@ async def test_returns_update_or_create_action_on_conf_change(
15351539
run_spec=run_spec,
15361540
)
15371541
run = run_model_to_run(run_model)
1542+
# Apply the same defaults the server applies to current_resource
1543+
set_resources_defaults(run.run_spec.configuration.resources)
1544+
set_gpu_vendor_default(
1545+
run.run_spec.configuration.resources,
1546+
image=run.run_spec.configuration.image,
1547+
docker=getattr(run.run_spec.configuration, "docker", None),
1548+
)
15381549
run_spec.configuration = new_conf
15391550
response = await client.post(
15401551
f"/api/project/{project.name}/runs/get_plan",

0 commit comments

Comments
 (0)