Skip to content

Commit 71e880a

Browse files
committed
feat(azure/rhel-ai): add GPU guardrails for instance selection
- Add isGPUCapableSize helper matching ND/NC series (NV excluded) - Shallow-copy ComputeRequestArgs before mutation to avoid caller side-effects - Default ComputeRequest.GPUs to 1 so filterCPUsAndMemory auto-selects only GPU-capable instance types when no explicit GPU count is set - Warn when caller explicitly provides compute sizes that are not GPU-capable (expected ND/NC series; vllm requires a GPU device)
1 parent 44ef648 commit 71e880a

2 files changed

Lines changed: 64 additions & 3 deletions

File tree

pkg/provider/azure/action/rhel-ai/rhelai.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,55 @@ func imageId(accelerator, version string) string {
3434
return imageIdFromName(fmt.Sprintf(imageNameRegex, accelerator, version))
3535
}
3636

37+
// isGPUCapableSize returns true for ND-series and NC-series Azure VM sizes,
38+
// which are the compute GPU families supported for RHEL AI workloads.
39+
// NV-series (visualization GPUs) is intentionally excluded.
40+
func isGPUCapableSize(vmSize string) bool {
41+
lower := strings.ToLower(vmSize)
42+
return strings.HasPrefix(lower, "standard_nd") || strings.HasPrefix(lower, "standard_nc")
43+
}
44+
3745
func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err error) {
38-
logging.Debug("Creating RHEL Server")
46+
logging.Debug("Creating RHEL AI Server")
3947
sharedImageID := imageId(args.Accelerator, args.Version)
4048
if args.CustomImage != "" {
4149
sharedImageID = imageIdFromName(args.CustomImage)
4250
}
51+
// Shallow-copy to avoid mutating the caller's ComputeRequestArgs.
52+
computeReq := *args.ComputeRequest
53+
// Ensure GPU-capable instance selection for auto-selection paths.
54+
if computeReq.GPUs == 0 {
55+
logging.Debug("RHEL AI: GPUs not set, defaulting to 1 for GPU-capable instance selection")
56+
computeReq.GPUs = 1
57+
}
58+
// Warn when the caller explicitly specifies compute sizes that are not GPU-capable.
59+
if len(computeReq.ComputeSizes) > 0 {
60+
allNonGPU := true
61+
for _, s := range computeReq.ComputeSizes {
62+
if isGPUCapableSize(s) {
63+
allNonGPU = false
64+
break
65+
}
66+
}
67+
if allNonGPU {
68+
return fmt.Errorf("RHEL AI: none of the specified compute sizes %v are GPU-capable "+
69+
"(expected ND-series or NC-series for vllm)", computeReq.ComputeSizes)
70+
}
71+
}
4372
azureLinuxRequest :=
4473
&azureLinux.LinuxArgs{
4574
Prefix: args.Prefix,
46-
ComputeRequest: args.ComputeRequest,
75+
ComputeRequest: &computeReq,
4776
Spot: args.Spot,
4877
ImageRef: &data.ImageReference{
4978
SharedImageID: sharedImageID,
5079
},
5180
Username: username,
5281
ReadinessCommand: command.CommandPing}
53-
return azureLinux.Create(mCtxArgs, azureLinuxRequest)
82+
if err = azureLinux.Create(mCtxArgs, azureLinuxRequest); err != nil && len(computeReq.ComputeSizes) == 0 {
83+
return fmt.Errorf("RHEL AI: failed to provision a GPU-capable instance (ND/NC-series required for vllm); verify GPU quota in the target location/subscription: %w", err)
84+
}
85+
return err
5486
}
5587

5688
func Destroy(mCtxArgs *maptContext.ContextArgs) error {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package rhelai
2+
3+
import "testing"
4+
5+
func TestIsGPUCapableSize(t *testing.T) {
6+
cases := []struct {
7+
size string
8+
expected bool
9+
}{
10+
{"Standard_ND96asr_v4", true},
11+
{"Standard_ND40rs_v2", true},
12+
{"Standard_NC6s_v3", true},
13+
{"Standard_NC24rs_v3", true},
14+
{"standard_nd96asr_v4", true},
15+
{"standard_nc6s_v3", true},
16+
{"Standard_D8as_v5", false},
17+
{"Standard_E16as_v5", false},
18+
{"Standard_F32s_v2", false},
19+
{"Standard_NV6", false},
20+
{"Standard_NV36ads_A10_v5", false},
21+
{"", false},
22+
}
23+
for _, tc := range cases {
24+
got := isGPUCapableSize(tc.size)
25+
if got != tc.expected {
26+
t.Errorf("isGPUCapableSize(%q) = %v, want %v", tc.size, got, tc.expected)
27+
}
28+
}
29+
}

0 commit comments

Comments
 (0)