diff --git a/pkg/provider/azure/action/rhel-ai/rhelai.go b/pkg/provider/azure/action/rhel-ai/rhelai.go index 3d2cd8e59..73e14159a 100644 --- a/pkg/provider/azure/action/rhel-ai/rhelai.go +++ b/pkg/provider/azure/action/rhel-ai/rhelai.go @@ -1,6 +1,7 @@ package rhelai import ( + "context" "fmt" "strings" @@ -35,7 +36,22 @@ func imageId(accelerator, version string) string { } func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err error) { - logging.Debug("Creating RHEL Server") + logging.Debug("Creating RHEL AI Server") + computeReq := *args.ComputeRequest + if len(computeReq.ComputeSizes) > 0 { + ctx := mCtxArgs.Context + if ctx == nil { + ctx = context.Background() + } + computeReq.ComputeSizes, err = data.FilterNoLocalStorageSizes( + ctx, computeReq.ComputeSizes) + if err != nil { + return err + } + if len(computeReq.ComputeSizes) == 0 { + return fmt.Errorf("no valid compute sizes: all provided sizes have NVMe-only local storage, incompatible with RHEL AI") + } + } sharedImageID := imageId(args.Accelerator, args.Version) if args.CustomImage != "" { sharedImageID = imageIdFromName(args.CustomImage) @@ -43,7 +59,7 @@ func Create(mCtxArgs *maptContext.ContextArgs, args *apiRHELAI.RHELAIArgs) (err azureLinuxRequest := &azureLinux.LinuxArgs{ Prefix: args.Prefix, - ComputeRequest: args.ComputeRequest, + ComputeRequest: &computeReq, Spot: args.Spot, ImageRef: &data.ImageReference{ SharedImageID: sharedImageID, diff --git a/pkg/provider/azure/data/compute-request.go b/pkg/provider/azure/data/compute-request.go index 0b50e61a9..760a12007 100644 --- a/pkg/provider/azure/data/compute-request.go +++ b/pkg/provider/azure/data/compute-request.go @@ -2,7 +2,6 @@ package data import ( "context" - "os" "regexp" "slices" "strconv" @@ -12,6 +11,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v7" cr "github.com/redhat-developer/mapt/pkg/provider/api/compute-request" + "github.com/redhat-developer/mapt/pkg/util/logging" ) const ( @@ -66,12 +66,77 @@ func FilterComputeSizesByLocation(ctx context.Context, location *string, compute return supportedSizes, nil } +// FilterNoLocalStorageSizes returns only the sizes from computeSizes that have no +// NVMe-only local storage (L-series). Temp disks (MaxResourceVolumeMB > 0) are allowed +// — they are ephemeral scratch space that does not interfere with RHEL AI's OS disk. +// Sizes not found in the Azure SKU catalog (typo or restricted SKU) are logged as +// warnings and excluded. +func FilterNoLocalStorageSizes(ctx context.Context, computeSizes []string) ([]string, error) { + creds, subscriptionID, err := getCredentials() + if err != nil { + return nil, err + } + client, err := armcompute.NewResourceSKUsClient(*subscriptionID, creds, nil) + if err != nil { + return nil, err + } + pager := client.NewListPager(nil) + capabilities := make(map[string]*virtualMachine, len(computeSizes)) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, sku := range page.Value { + if sku.ResourceType == nil || *sku.ResourceType != string(RTVirtualMachines) { + continue + } + if sku.Name == nil || !slices.Contains(computeSizes, *sku.Name) { + continue + } + if _, seen := capabilities[*sku.Name]; seen { + continue + } + if vm := resourceSKUToVirtualMachine(sku); vm != nil { + capabilities[*sku.Name] = vm + } + } + } + valid, dropped, unknown := filterNVMeStorage(computeSizes, capabilities) + for _, size := range dropped { + logging.Warnf("dropping compute size %q: has NVMe-only local storage, incompatible with RHEL AI", size) + } + for _, size := range unknown { + logging.Warnf("dropping compute size %q: not found in Azure SKU catalog (typo or restricted SKU)", size) + } + return valid, nil +} + +// filterNVMeStorage classifies each size into valid (no NVMe-only local storage), +// dropped (has NVMe local storage — e.g. L-series), or unknown (absent from capabilities). +func filterNVMeStorage(computeSizes []string, capabilities map[string]*virtualMachine) (valid, dropped, unknown []string) { + for _, size := range computeSizes { + vm, ok := capabilities[size] + if !ok { + unknown = append(unknown, size) + continue + } + if vm.NvmeDiskSizeInMiB > 0 { + dropped = append(dropped, size) + } else { + valid = append(valid, size) + } + } + return valid, dropped, unknown +} + func getAzureVMSKUs(ctx context.Context, args *cr.ComputeRequestArgs) ([]string, error) { + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, err } - subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + subscriptionId := SubscriptionID() clientFactory, err := armcompute.NewClientFactory( subscriptionId, cred, nil) if err != nil { @@ -109,6 +174,10 @@ type virtualMachine struct { // Spot capable LowPriorityCapable bool MaxResourceVolumeMB int32 + // L-series VMs expose NVMe storage separately from the temp disk + NvmeDiskSizeInMiB int32 + // Used by the disk-controller-type fix (PR #823) to cross-reference SKU capabilities + DiskControllerTypes []string // IaaS or PaaS VMDeploymentTypes []string // Fast SSD @@ -144,17 +213,17 @@ func (vm *virtualMachine) hypervGen2Supported() bool { return slices.Contains(vm.HyperVGenerations, "V2") } -func (vm *virtualMachine) emptyDiskSupported() bool { - return vm.MaxResourceVolumeMB == 0 +func (vm *virtualMachine) noLocalStorageAttached() bool { + return vm.MaxResourceVolumeMB == 0 && vm.NvmeDiskSizeInMiB == 0 } func (vm *virtualMachine) baseFeaturesSupported() bool { return vm.AcceleratedNetworkingEnabled && vm.PremiumIO && vm.EncryptionAtHostSupported && - vm.emptyDiskSupported() && vm.hypervGen2Supported() + vm.noLocalStorageAttached() && vm.hypervGen2Supported() } func resourceSKUToVirtualMachine(res *armcompute.ResourceSKU) *virtualMachine { - if res.ResourceType != nil && *res.ResourceType != "virtualMachines" { + if res.ResourceType != nil && *res.ResourceType != string(RTVirtualMachines) { return nil } // If Machine type has any type of restriccions discard @@ -219,6 +288,14 @@ func resourceSKUToVirtualMachine(res *armcompute.ResourceSKU) *virtualMachine { return nil } vm.MaxResourceVolumeMB = int32(disk) + case "NvmeDiskSizeInMiB": + nvme, err := strconv.ParseUint(*capability.Value, 10, 32) + if err != nil { + return nil + } + vm.NvmeDiskSizeInMiB = int32(nvme) + case "DiskControllerTypes": + vm.DiskControllerTypes = strings.Split(*capability.Value, ",") case "VMDeploymentTypes": vm.VMDeploymentTypes = strings.Split(*capability.Value, ",") default: diff --git a/pkg/provider/azure/data/compute-request_test.go b/pkg/provider/azure/data/compute-request_test.go new file mode 100644 index 000000000..86f6db043 --- /dev/null +++ b/pkg/provider/azure/data/compute-request_test.go @@ -0,0 +1,183 @@ +package data + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v7" +) + +func ptr[T any](v T) *T { return &v } + +// noLocalStorageAttached tests + +func TestNoLocalStorageAttached_NoTempDiskNoNvme(t *testing.T) { + vm := &virtualMachine{MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 0} + if !vm.noLocalStorageAttached() { + t.Error("expected true: VM with no temp disk and no NVMe should have no local storage") + } +} + +func TestNoLocalStorageAttached_HasTempDisk(t *testing.T) { + vm := &virtualMachine{MaxResourceVolumeMB: 512, NvmeDiskSizeInMiB: 0} + if vm.noLocalStorageAttached() { + t.Error("expected false: VM with temp disk should have local storage") + } +} + +func TestNoLocalStorageAttached_HasNvmeDisk(t *testing.T) { + // L-series bug case: MaxResourceVolumeMB=0 but NvmeDiskSizeInMiB>0 + vm := &virtualMachine{MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 5492736} + if vm.noLocalStorageAttached() { + t.Error("expected false: L-series VM with NVMe storage should have local storage") + } +} + +func TestNoLocalStorageAttached_HasBoth(t *testing.T) { + vm := &virtualMachine{MaxResourceVolumeMB: 512, NvmeDiskSizeInMiB: 5492736} + if vm.noLocalStorageAttached() { + t.Error("expected false: VM with both temp disk and NVMe should have local storage") + } +} + +// resourceSKUToVirtualMachine parsing tests + +func TestResourceSKUToVirtualMachine_ParsesNvmeDiskSizeInMiB(t *testing.T) { + sku := &armcompute.ResourceSKU{ + ResourceType: ptr("virtualMachines"), + Name: ptr("Standard_L8aos_v4"), + Family: ptr("standardLasv4Family"), + Capabilities: []*armcompute.ResourceSKUCapabilities{ + {Name: ptr("NvmeDiskSizeInMiB"), Value: ptr("5492736")}, + }, + } + vm := resourceSKUToVirtualMachine(sku) + if vm == nil { + t.Fatal("expected non-nil virtualMachine") + } + if vm.NvmeDiskSizeInMiB != 5492736 { + t.Errorf("NvmeDiskSizeInMiB: got %d, want 5492736", vm.NvmeDiskSizeInMiB) + } +} + +func TestResourceSKUToVirtualMachine_ParsesDiskControllerTypes(t *testing.T) { + sku := &armcompute.ResourceSKU{ + ResourceType: ptr("virtualMachines"), + Name: ptr("Standard_L8aos_v4"), + Family: ptr("standardLasv4Family"), + Capabilities: []*armcompute.ResourceSKUCapabilities{ + {Name: ptr("DiskControllerTypes"), Value: ptr("NVMe,SCSI")}, + }, + } + vm := resourceSKUToVirtualMachine(sku) + if vm == nil { + t.Fatal("expected non-nil virtualMachine") + } + if len(vm.DiskControllerTypes) != 2 { + t.Fatalf("DiskControllerTypes: got %v, want [NVMe SCSI]", vm.DiskControllerTypes) + } + if vm.DiskControllerTypes[0] != "NVMe" || vm.DiskControllerTypes[1] != "SCSI" { + t.Errorf("DiskControllerTypes: got %v, want [NVMe SCSI]", vm.DiskControllerTypes) + } +} + +func TestResourceSKUToVirtualMachine_NvmeDiskSizeDefaultsToZero(t *testing.T) { + sku := &armcompute.ResourceSKU{ + ResourceType: ptr("virtualMachines"), + Name: ptr("Standard_D8as_v5"), + Family: ptr("standardDasv5Family"), + Capabilities: []*armcompute.ResourceSKUCapabilities{ + {Name: ptr("MaxResourceVolumeMB"), Value: ptr("307200")}, + }, + } + vm := resourceSKUToVirtualMachine(sku) + if vm == nil { + t.Fatal("expected non-nil virtualMachine") + } + if vm.NvmeDiskSizeInMiB != 0 { + t.Errorf("NvmeDiskSizeInMiB: got %d, want 0 for non-NVMe SKU", vm.NvmeDiskSizeInMiB) + } +} + +// filterNVMeStorage tests + +func TestFilterNVMeStorage_DropsNvmeSizes(t *testing.T) { + capabilities := map[string]*virtualMachine{ + "Standard_D8as_v5": {MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 0}, + "Standard_L8aos_v4": {MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 5492736}, + } + got, dropped, unknown := filterNVMeStorage([]string{"Standard_D8as_v5", "Standard_L8aos_v4"}, capabilities) + if len(got) != 1 || got[0] != "Standard_D8as_v5" { + t.Errorf("filtered: got %v, want [Standard_D8as_v5]", got) + } + if len(dropped) != 1 || dropped[0] != "Standard_L8aos_v4" { + t.Errorf("dropped: got %v, want [Standard_L8aos_v4]", dropped) + } + if len(unknown) != 0 { + t.Errorf("unknown: got %v, want []", unknown) + } +} + +func TestFilterNVMeStorage_AllowsTempDiskSizes(t *testing.T) { + capabilities := map[string]*virtualMachine{ + "Standard_NC64as_T4_v3": {MaxResourceVolumeMB: 32768, NvmeDiskSizeInMiB: 0}, + } + got, dropped, unknown := filterNVMeStorage([]string{"Standard_NC64as_T4_v3"}, capabilities) + if len(got) != 1 { + t.Errorf("filtered: got %v, want [Standard_NC64as_T4_v3]", got) + } + if len(dropped) != 0 { + t.Errorf("dropped: got %v, want []", dropped) + } + if len(unknown) != 0 { + t.Errorf("unknown: got %v, want []", unknown) + } +} + +func TestFilterNVMeStorage_PassesCleanSizes(t *testing.T) { + capabilities := map[string]*virtualMachine{ + "Standard_D8as_v5": {MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 0}, + } + got, dropped, unknown := filterNVMeStorage([]string{"Standard_D8as_v5"}, capabilities) + if len(got) != 1 { + t.Errorf("filtered: got %v, want [Standard_D8as_v5]", got) + } + if len(dropped) != 0 { + t.Errorf("dropped: got %v, want []", dropped) + } + if len(unknown) != 0 { + t.Errorf("unknown: got %v, want []", unknown) + } +} + +func TestFilterNVMeStorage_ReportsUnknownSizes(t *testing.T) { + capabilities := map[string]*virtualMachine{ + "Standard_D8as_v5": {MaxResourceVolumeMB: 0, NvmeDiskSizeInMiB: 0}, + } + got, dropped, unknown := filterNVMeStorage( + []string{"Standard_D8as_v5", "Standard_Typo_v99"}, capabilities) + if len(got) != 1 || got[0] != "Standard_D8as_v5" { + t.Errorf("filtered: got %v, want [Standard_D8as_v5]", got) + } + if len(dropped) != 0 { + t.Errorf("dropped: got %v, want []", dropped) + } + if len(unknown) != 1 || unknown[0] != "Standard_Typo_v99" { + t.Errorf("unknown: got %v, want [Standard_Typo_v99]", unknown) + } +} + +// baseFeaturesSupported regression: L-series must be rejected + +func TestBaseFeaturesSupported_LSeriesWithNvmeIsRejected(t *testing.T) { + vm := &virtualMachine{ + MaxResourceVolumeMB: 0, + NvmeDiskSizeInMiB: 5492736, + AcceleratedNetworkingEnabled: true, + PremiumIO: true, + EncryptionAtHostSupported: true, + HyperVGenerations: []string{"V1", "V2"}, + } + if vm.baseFeaturesSupported() { + t.Error("expected false: L-series VM with NVMe storage must not pass baseFeaturesSupported") + } +} diff --git a/pkg/provider/azure/data/images.go b/pkg/provider/azure/data/images.go index aced2ddac..690dbe7ef 100644 --- a/pkg/provider/azure/data/images.go +++ b/pkg/provider/azure/data/images.go @@ -3,7 +3,6 @@ package data import ( "context" "fmt" - "os" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -16,11 +15,12 @@ type ImageRequest struct { } func IsImageOffered(ctx context.Context, req ImageRequest) error { + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return err } - subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + subscriptionId := SubscriptionID() clientFactory, err := armcompute.NewClientFactory(subscriptionId, cred, nil) if err != nil { return err @@ -53,11 +53,12 @@ func getCommunityImage(ctx context.Context, c *armcompute.ClientFactory, id, reg } func GetSharedImage(ctx context.Context, id *string) (*armcompute.GalleryImageVersionsClientGetResponse, error) { + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, err } - subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + subscriptionId := SubscriptionID() c, err := armcompute.NewClientFactory(subscriptionId, cred, nil) if err != nil { return nil, err @@ -86,11 +87,12 @@ func getSharedImage(ctx context.Context, c *armcompute.ClientFactory, id *string } func SkuG2Support(ctx context.Context, location string, publisher string, offer string, sku string) (string, error) { + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return "", err } - subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + subscriptionId := SubscriptionID() clientFactory, err := armcompute.NewClientFactory(subscriptionId, cred, nil) if err != nil { diff --git a/pkg/provider/azure/data/util.go b/pkg/provider/azure/data/util.go index cbb9aa20a..a119099d6 100644 --- a/pkg/provider/azure/data/util.go +++ b/pkg/provider/azure/data/util.go @@ -2,26 +2,54 @@ package data import ( "os" + "strings" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph" ) -const ( - ENV_AZURE_SUBSCRIPTION_ID = "AZURE_SUBSCRIPTION_ID" -) +var azureIdentityEnvs = []string{ + "AZURE_TENANT_ID", + "AZURE_SUBSCRIPTION_ID", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", +} + +// ensureAzureEnvs maps ARM_* env vars to AZURE_* if the AZURE_* vars are unset. +// Safe to call multiple times — only sets vars that are currently empty. +func ensureAzureEnvs() { + for _, e := range azureIdentityEnvs { + if os.Getenv(e) == "" { + armKey := strings.ReplaceAll(e, "AZURE", "ARM") + if v := os.Getenv(armKey); v != "" { + os.Setenv(e, v) + } + } + } +} + +// SubscriptionID returns the Azure subscription ID, checking AZURE_SUBSCRIPTION_ID +// first, then falling back to ARM_SUBSCRIPTION_ID (Pulumi/Terraform convention). +func SubscriptionID() string { + if v := os.Getenv("AZURE_SUBSCRIPTION_ID"); v != "" { + return v + } + return os.Getenv("ARM_SUBSCRIPTION_ID") +} func getCredentials() (cred *azidentity.DefaultAzureCredential, subscriptionID *string, err error) { + ensureAzureEnvs() cred, err = azidentity.NewDefaultAzureCredential(nil) if err != nil { return } - azSubsID := os.Getenv(ENV_AZURE_SUBSCRIPTION_ID) + azSubsID := SubscriptionID() subscriptionID = &azSubsID return } func getGraphClientFactory() (*armresourcegraph.ClientFactory, error) { + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, err diff --git a/pkg/provider/azure/data/vmsize.go b/pkg/provider/azure/data/vmsize.go index 1a101f4fe..6e67ed253 100644 --- a/pkg/provider/azure/data/vmsize.go +++ b/pkg/provider/azure/data/vmsize.go @@ -2,7 +2,6 @@ package data import ( "context" - "os" "slices" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -17,12 +16,12 @@ func IsVMSizeOfferedByLocation(ctx context.Context, vmSize, location string) (bo // Get InstanceTypes offerings on current location func FilterVMSizeOfferedByLocation(ctx context.Context, vmSizes []string, location string) ([]string, error) { - // Create a new Azure credential (uses environment variables or managed identity) + ensureAzureEnvs() cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, err } - subscriptionId := os.Getenv("AZURE_SUBSCRIPTION_ID") + subscriptionId := SubscriptionID() clientFactory, err := armcompute.NewClientFactory(subscriptionId, cred, nil) if err != nil { return nil, err diff --git a/pkg/provider/azure/modules/virtual-machine/virtual-machine.go b/pkg/provider/azure/modules/virtual-machine/virtual-machine.go index c700b1603..3441760fe 100644 --- a/pkg/provider/azure/modules/virtual-machine/virtual-machine.go +++ b/pkg/provider/azure/modules/virtual-machine/virtual-machine.go @@ -2,7 +2,6 @@ package virtualmachine import ( "fmt" - "os" "strings" "github.com/pulumi/pulumi-azure-native-sdk/compute/v3" @@ -161,5 +160,5 @@ func convertImageRef(mCtx *mc.Context, i data.ImageReference, location string) ( func isSelfOwned(sharedImageId *string) bool { sharedImageParams := strings.Split(*sharedImageId, "/") - return os.Getenv("AZURE_SUBSCRIPTION_ID") == sharedImageParams[2] + return data.SubscriptionID() == sharedImageParams[2] }