Skip to content

Commit 724c82d

Browse files
rishupkclaude
andcommitted
feat(azure/rhel-ai): add GPU guardrails for instance selection
- Add GPU capability detection to VM SKU filter - Validate that selected compute sizes are GPU-capable (ND/NC-series) - Default GPUs=1 when unset so spot allocator targets GPU instances Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Rishabh Kothari <rkothari@redhat.com>
1 parent 8d0f721 commit 724c82d

3 files changed

Lines changed: 80 additions & 4 deletions

File tree

pkg/provider/azure/action/rhel-ai/rhelai.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,41 @@ func imageId(accelerator, version string) string {
3434
return imageIdFromName(fmt.Sprintf(imageNameRegex, accelerator, version))
3535
}
3636

37+
// isGPUCapableSize returns true for ND-series and NC-series Azure VM sizes,
38+
// which are the compute GPU families supported for RHEL AI workloads.
39+
// NV-series (visualization GPUs) is intentionally excluded.
40+
func isGPUCapableSize(vmSize string) bool {
41+
lower := strings.ToLower(vmSize)
42+
return strings.HasPrefix(lower, "standard_nd") || strings.HasPrefix(lower, "standard_nc")
43+
}
44+
3745
func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err error) {
38-
logging.Debug("Creating RHEL Server")
46+
if args == nil || args.ComputeRequest == nil {
47+
return fmt.Errorf("RHEL AI: args and ComputeRequest must not be nil")
48+
}
49+
logging.Debug("Creating RHEL AI Server")
3950
sharedImageID := imageId(args.Accelerator, args.Version)
4051
if args.CustomImage != "" {
4152
sharedImageID = imageIdFromName(args.CustomImage)
4253
}
54+
// Shallow-copy to avoid mutating the caller's ComputeRequestArgs.
55+
computeReq := *args.ComputeRequest
56+
// Ensure GPU-capable instance selection for auto-selection paths.
57+
if computeReq.GPUs == 0 {
58+
logging.Debug("RHEL AI: GPUs not set, defaulting to 1 for GPU-capable instance selection")
59+
computeReq.GPUs = 1
60+
}
61+
// All explicitly specified sizes must be GPU-capable; a single non-GPU entry
62+
// could get allocated and vllm would fail silently.
63+
for _, s := range computeReq.ComputeSizes {
64+
if !isGPUCapableSize(s) {
65+
return fmt.Errorf("RHEL AI: %q is not GPU-capable (expected ND-series or NC-series for vllm)", s)
66+
}
67+
}
4368
azureLinuxRequest :=
4469
&azureLinux.LinuxArgs{
4570
Prefix: args.Prefix,
46-
ComputeRequest: args.ComputeRequest,
71+
ComputeRequest: &computeReq,
4772
Spot: args.Spot,
4873
ImageRef: &data.ImageReference{
4974
SharedImageID: sharedImageID,
@@ -55,7 +80,10 @@ func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err
5580
},
5681
Username: username,
5782
ReadinessCommand: command.CommandPing}
58-
return azureLinux.Create(mCtxArgs, azureLinuxRequest)
83+
if err = azureLinux.Create(mCtxArgs, azureLinuxRequest); err != nil && len(computeReq.ComputeSizes) == 0 {
84+
return fmt.Errorf("RHEL AI: failed to provision a GPU-capable instance (ND/NC-series required for vllm); verify GPU quota in the target location/subscription: %w", err)
85+
}
86+
return err
5987
}
6088

6189
func Destroy(mCtxArgs *maptContext.ContextArgs) error {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package rhelai
2+
3+
import "testing"
4+
5+
func TestIsGPUCapableSize(t *testing.T) {
6+
cases := []struct {
7+
size string
8+
expected bool
9+
}{
10+
{"Standard_ND96asr_v4", true},
11+
{"Standard_ND40rs_v2", true},
12+
{"Standard_NC6s_v3", true},
13+
{"Standard_NC24rs_v3", true},
14+
{"standard_nd96asr_v4", true},
15+
{"standard_nc6s_v3", true},
16+
{"Standard_D8as_v5", false},
17+
{"Standard_E16as_v5", false},
18+
{"Standard_F32s_v2", false},
19+
{"Standard_NV6", false},
20+
{"Standard_NV36ads_A10_v5", false},
21+
{"", false},
22+
}
23+
for _, tc := range cases {
24+
got := isGPUCapableSize(tc.size)
25+
if got != tc.expected {
26+
t.Errorf("isGPUCapableSize(%q) = %v, want %v", tc.size, got, tc.expected)
27+
}
28+
}
29+
}

pkg/provider/azure/data/compute-request.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ type virtualMachine struct {
151151
// Spot capable
152152
LowPriorityCapable bool
153153
MaxResourceVolumeMB int32
154+
GPUs int32
154155
// IaaS or PaaS
155156
VMDeploymentTypes []string
156157
// Fast SSD
@@ -261,6 +262,12 @@ func resourceSKUToVirtualMachine(res *armcompute.ResourceSKU) *virtualMachine {
261262
return nil
262263
}
263264
vm.MaxResourceVolumeMB = int32(disk)
265+
case "GPUs":
266+
gpus, err := strconv.ParseInt(*capability.Value, 10, 32)
267+
if err != nil {
268+
return nil
269+
}
270+
vm.GPUs = int32(gpus)
264271
case "VMDeploymentTypes":
265272
vm.VMDeploymentTypes = strings.Split(*capability.Value, ",")
266273
default:
@@ -283,10 +290,22 @@ func filterCPUsAndMemory(args *cr.ComputeRequestArgs) filterFunc {
283290
if args.NestedVirt && !vm.nestedVirtSupported() {
284291
return
285292
}
293+
if args.GPUs > 0 && vm.GPUs < args.GPUs {
294+
return
295+
}
296+
// GPU VMs (ND/NC-series) have large temp disks, so skip the
297+
// local-storage check that would otherwise reject them.
298+
featuresOK := false
299+
if args.GPUs > 0 {
300+
featuresOK = vm.AcceleratedNetworkingEnabled && vm.PremiumIO &&
301+
vm.EncryptionAtHostSupported && vm.hypervGen2Supported()
302+
} else {
303+
featuresOK = vm.baseFeaturesSupported()
304+
}
286305
if vm.VCPUs >= args.CPUs &&
287306
vm.Memory >= args.MemoryGib &&
288307
vm.Arch == args.Arch.String() &&
289-
vm.baseFeaturesSupported() {
308+
featuresOK {
290309
dSeries := regexp.MustCompile(lowerCpuPattern)
291310
if !dSeries.Match([]byte(vm.Name)) {
292311
vmCh <- vm.Name

0 commit comments

Comments
 (0)