Skip to content

Commit ba07858

Browse files
committed
feat(azure/rhel-ai): add GPU guardrails for instance selection
- Add isGPUCapableSize helper matching ND/NC series (NV excluded) - Shallow-copy ComputeRequestArgs before mutation to avoid caller side-effects - Default ComputeRequest.GPUs to 1 so filterCPUsAndMemory auto-selects only GPU-capable instance types when no explicit GPU count is set - Warn when caller explicitly provides compute sizes that are not GPU-capable (expected ND/NC series; vllm requires a GPU device)
1 parent 44ef648 commit ba07858

3 files changed

Lines changed: 70 additions & 4 deletions

File tree

pkg/provider/azure/action/rhel-ai/rhelai.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,51 @@ func imageId(accelerator, version string) string {
3434
return imageIdFromName(fmt.Sprintf(imageNameRegex, accelerator, version))
3535
}
3636

37+
// isGPUCapableSize returns true for ND-series and NC-series Azure VM sizes,
38+
// which are the compute GPU families supported for RHEL AI workloads.
39+
// NV-series (visualization GPUs) is intentionally excluded.
40+
func isGPUCapableSize(vmSize string) bool {
41+
lower := strings.ToLower(vmSize)
42+
return strings.HasPrefix(lower, "standard_nd") || strings.HasPrefix(lower, "standard_nc")
43+
}
44+
3745
func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err error) {
38-
logging.Debug("Creating RHEL Server")
46+
if args == nil || args.ComputeRequest == nil {
47+
return fmt.Errorf("RHEL AI: args and ComputeRequest must not be nil")
48+
}
49+
logging.Debug("Creating RHEL AI Server")
3950
sharedImageID := imageId(args.Accelerator, args.Version)
4051
if args.CustomImage != "" {
4152
sharedImageID = imageIdFromName(args.CustomImage)
4253
}
54+
// Shallow-copy to avoid mutating the caller's ComputeRequestArgs.
55+
computeReq := *args.ComputeRequest
56+
// Ensure GPU-capable instance selection for auto-selection paths.
57+
if computeReq.GPUs == 0 {
58+
logging.Debug("RHEL AI: GPUs not set, defaulting to 1 for GPU-capable instance selection")
59+
computeReq.GPUs = 1
60+
}
61+
// All explicitly specified sizes must be GPU-capable; a single non-GPU entry
62+
// could get allocated and vllm would fail silently.
63+
for _, s := range computeReq.ComputeSizes {
64+
if !isGPUCapableSize(s) {
65+
return fmt.Errorf("RHEL AI: %q is not GPU-capable (expected ND-series or NC-series for vllm)", s)
66+
}
67+
}
4368
azureLinuxRequest :=
4469
&azureLinux.LinuxArgs{
4570
Prefix: args.Prefix,
46-
ComputeRequest: args.ComputeRequest,
71+
ComputeRequest: &computeReq,
4772
Spot: args.Spot,
4873
ImageRef: &data.ImageReference{
4974
SharedImageID: sharedImageID,
5075
},
5176
Username: username,
5277
ReadinessCommand: command.CommandPing}
53-
return azureLinux.Create(mCtxArgs, azureLinuxRequest)
78+
if err = azureLinux.Create(mCtxArgs, azureLinuxRequest); err != nil && len(computeReq.ComputeSizes) == 0 {
79+
return fmt.Errorf("RHEL AI: failed to provision a GPU-capable instance (ND/NC-series required for vllm); verify GPU quota in the target location/subscription: %w", err)
80+
}
81+
return err
5482
}
5583

5684
func Destroy(mCtxArgs *maptContext.ContextArgs) error {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package rhelai
2+
3+
import "testing"
4+
5+
func TestIsGPUCapableSize(t *testing.T) {
6+
cases := []struct {
7+
size string
8+
expected bool
9+
}{
10+
{"Standard_ND96asr_v4", true},
11+
{"Standard_ND40rs_v2", true},
12+
{"Standard_NC6s_v3", true},
13+
{"Standard_NC24rs_v3", true},
14+
{"standard_nd96asr_v4", true},
15+
{"standard_nc6s_v3", true},
16+
{"Standard_D8as_v5", false},
17+
{"Standard_E16as_v5", false},
18+
{"Standard_F32s_v2", false},
19+
{"Standard_NV6", false},
20+
{"Standard_NV36ads_A10_v5", false},
21+
{"", false},
22+
}
23+
for _, tc := range cases {
24+
got := isGPUCapableSize(tc.size)
25+
if got != tc.expected {
26+
t.Errorf("isGPUCapableSize(%q) = %v, want %v", tc.size, got, tc.expected)
27+
}
28+
}
29+
}

pkg/provider/azure/data/compute-request.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,19 @@ func filterCPUsAndMemory(args *cr.ComputeRequestArgs) filterFunc {
251251
if args.GPUs > 0 && vm.GPUs < args.GPUs {
252252
return
253253
}
254+
// GPU VMs (ND/NC-series) have large temp disks, so skip the
255+
// local-storage check that would otherwise reject them.
256+
featuresOK := false
257+
if args.GPUs > 0 {
258+
featuresOK = vm.AcceleratedNetworkingEnabled && vm.PremiumIO &&
259+
vm.EncryptionAtHostSupported && vm.hypervGen2Supported()
260+
} else {
261+
featuresOK = vm.baseFeaturesSupported()
262+
}
254263
if vm.VCPUs >= args.CPUs &&
255264
vm.Memory >= args.MemoryGib &&
256265
vm.Arch == args.Arch.String() &&
257-
vm.baseFeaturesSupported() {
266+
featuresOK {
258267
dSeries := regexp.MustCompile(lowerCpuPattern)
259268
if !dSeries.Match([]byte(vm.Name)) {
260269
vmCh <- vm.Name

0 commit comments

Comments
 (0)