Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/renovate.json
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,16 @@
"enabled": true,
"ignoreUnstable": false
},
{
"matchPackageNames": [
"aks/aks-gpu-grid-v20"
],
"groupName": "nvidia-gpu-grid-v20",
"versioning": "regex:^(?<major>\\d+)\\.(?<minor>\\d+)\\.(?<patch>\\d+)-(?<prerelease>\\d{14})$",
"automerge": false,
"enabled": true,
"ignoreUnstable": false
},
{
"matchPackageNames": [
"azuremonitor/containerinsights/ciprod/prometheus-collector/images"
Expand Down
7 changes: 7 additions & 0 deletions parts/common/components.json
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,13 @@
"renovateTag": "registry=https://mcr.microsoft.com, name=aks/aks-gpu-grid",
"latestVersion": "570.211.01-20260522192315"
}
},
{
"downloadURL": "mcr.microsoft.com/aks/aks-gpu-grid-v20:*",
"gpuVersion": {
"renovateTag": "registry=https://mcr.microsoft.com, name=aks/aks-gpu-grid-v20",
"latestVersion": "595.58.03-20260101000000"
}
}
],
"Packages": [
Expand Down
15 changes: 15 additions & 0 deletions pkg/agent/baker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,9 @@ func getPortRangeEndValue(portRange string) int {
// NVv1 seems to run with CUDA, NVv5 requires GRID.
// NVv3 is untested on AKS, NVv4 is AMD so n/a, and NVv2 no longer seems to exist (?).
func GetGPUDriverVersion(size string) string {
if useGridV20Drivers(size) {
return datamodel.NvidiaGridV20DriverVersion
}
if useGridDrivers(size) {
return datamodel.NvidiaGridDriverVersion
}
Expand All @@ -1457,14 +1460,26 @@ func useGridDrivers(size string) bool {
return datamodel.ConvergedGPUDriverSizes[strings.ToLower(size)]
}

// useGridV20Drivers tells the caller whether the given VM size is listed in
// datamodel.RTXPro6000GPUDriverSizes, i.e. whether it should get the GRID v20
// (595.x) driver image (aks-gpu-grid-v20) instead of the standard GRID image
// (aks-gpu-grid). The size is lower-cased before the lookup, so the match does
// not depend on the casing the caller passes in.
func useGridV20Drivers(size string) bool {
	sku := strings.ToLower(size)
	return datamodel.RTXPro6000GPUDriverSizes[sku]
}

// GetAKSGPUImageSHA returns the aks-gpu image version suffix that matches the
// driver flavour selected for the given VM size.
//
// Precedence mirrors the other GPU helpers in this file: GRID v20 sizes are
// checked first, then standard GRID sizes; every other size falls back to the
// CUDA suffix.
func GetAKSGPUImageSHA(size string) string {
	switch {
	case useGridV20Drivers(size):
		return datamodel.AKSGPUGridV20VersionSuffix
	case useGridDrivers(size):
		return datamodel.AKSGPUGridVersionSuffix
	default:
		return datamodel.AKSGPUCudaVersionSuffix
	}
}

func GetGPUDriverType(size string) string {
if useGridV20Drivers(size) {
return "grid-v20"
}
if useGridDrivers(size) {
return "grid"
}
Expand Down
11 changes: 11 additions & 0 deletions pkg/agent/baker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,10 @@ var _ = Describe("GetGPUDriverVersion", func() {
Expect(GetGPUDriverVersion("standard_nv6ads_a10_v5")).To(Equal(datamodel.NvidiaGridDriverVersion))
Expect(GetGPUDriverVersion("Standard_nv36adms_A10_V5")).To(Equal(datamodel.NvidiaGridDriverVersion))
})
It("should use grid v20 with rtx pro 6000 bse v6", func() {
Expect(GetGPUDriverVersion("standard_nc128ds_xl_rtxpro6000bse_v6")).To(Equal(datamodel.NvidiaGridV20DriverVersion))
Expect(GetGPUDriverVersion("Standard_NC320ds_xl_RTXPRO6000BSE_v6")).To(Equal(datamodel.NvidiaGridV20DriverVersion))
})
// NV V1 SKUs were retired in September 2023, leaving this test just for safety
It("should use cuda with nv v1", func() {
Expect(GetGPUDriverVersion("standard_nv6")).To(Equal(datamodel.NvidiaCudaDriverVersion))
Expand All @@ -967,6 +971,10 @@ var _ = Describe("GetGPUDriverType", func() {
Expect(GetGPUDriverType("standard_nv6ads_a10_v5")).To(Equal("grid"))
Expect(GetGPUDriverType("Standard_nv36adms_A10_V5")).To(Equal("grid"))
})
It("should use grid-v20 with rtx pro 6000 bse v6", func() {
Expect(GetGPUDriverType("standard_nc128ds_xl_rtxpro6000bse_v6")).To(Equal("grid-v20"))
Expect(GetGPUDriverType("Standard_NC320ds_xl_RTXPRO6000BSE_v6")).To(Equal("grid-v20"))
})
// NV V1 SKUs were retired in September 2023, leaving this test just for safety
It("should use cuda with nv v1", func() {
Expect(GetGPUDriverType("standard_nv6")).To(Equal("cuda"))
Expand All @@ -977,6 +985,9 @@ var _ = Describe("GetAKSGPUImageSHA", func() {
It("should use newest AKSGPUGridVersionSuffix with nv v5", func() {
Expect(GetAKSGPUImageSHA("standard_nv6ads_a10_v5")).To(Equal(datamodel.AKSGPUGridVersionSuffix))
})
It("should use newest AKSGPUGridV20VersionSuffix with rtx pro 6000 bse v6", func() {
Expect(GetAKSGPUImageSHA("standard_nc128ds_xl_rtxpro6000bse_v6")).To(Equal(datamodel.AKSGPUGridV20VersionSuffix))
})
It("should use newest AKSGPUCudaVersionSuffix with non grid SKU", func() {
Expect(GetAKSGPUImageSHA("standard_nc6_v3")).To(Equal(datamodel.AKSGPUCudaVersionSuffix))
})
Expand Down
45 changes: 39 additions & 6 deletions pkg/agent/datamodel/gpu_components.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ const Nvidia470CudaDriverVersion = "cuda-470.82.01"

//nolint:gochecknoglobals
var (
NvidiaCudaDriverVersion string
NvidiaGridDriverVersion string
AKSGPUCudaVersionSuffix string
AKSGPUGridVersionSuffix string
NvidiaCudaDriverVersion string
NvidiaGridDriverVersion string
NvidiaGridV20DriverVersion string
AKSGPUCudaVersionSuffix string
AKSGPUGridVersionSuffix string
AKSGPUGridV20VersionSuffix string
)

type gpuVersion struct {
Expand Down Expand Up @@ -55,17 +57,37 @@ func LoadConfig() error {
}
version, suffix := parts[driverIndex], parts[suffixIndex]

if strings.Contains(image.DownloadURL, "aks-gpu-cuda") {
// Match on the exact repo name (final path segment, tag stripped) so that
// repos sharing a prefix (e.g. "aks-gpu-grid" vs "aks-gpu-grid-v20") are not
// confused by substring matching.
switch gpuImageRepo(image.DownloadURL) {
case "aks-gpu-cuda":
NvidiaCudaDriverVersion = version
AKSGPUCudaVersionSuffix = suffix
} else if strings.Contains(image.DownloadURL, "aks-gpu-grid") {
case "aks-gpu-grid":
NvidiaGridDriverVersion = version
AKSGPUGridVersionSuffix = suffix
case "aks-gpu-grid-v20":
NvidiaGridV20DriverVersion = version
AKSGPUGridV20VersionSuffix = suffix
}
}
return nil
}

// gpuImageRepo extracts the bare repository name from an image download URL,
// e.g. "mcr.microsoft.com/aks/aks-gpu-grid-v20:*" -> "aks-gpu-grid-v20".
//
// Only the final path segment is inspected, so a registry host that carries a
// port ("host:5000/aks/aks-gpu-grid:*") is not mistaken for a tag separator.
// Both a tag (":tag") and a digest ("@sha256:...") are stripped, so a
// digest-pinned reference resolves to the same repo name as a tagged one.
// An input without "/", ":" or "@" is returned unchanged.
func gpuImageRepo(downloadURL string) string {
	// LastIndex yields -1 when there is no "/", which makes the slice start at 0
	// and keeps the whole string.
	repo := downloadURL[strings.LastIndex(downloadURL, "/")+1:]
	// Cut at whichever of ":" (tag) or "@" (digest) comes first.
	if end := strings.IndexAny(repo, ":@"); end != -1 {
		repo = repo[:end]
	}
	return repo
}

//nolint:gochecknoinits
func init() {
if err := LoadConfig(); err != nil {
Expand Down Expand Up @@ -93,6 +115,17 @@ var ConvergedGPUDriverSizes = map[string]bool{
"standard_nc32ads_a10_v4": true,
}

/* RTXPro6000GPUDriverSizes : NC_RTXPRO6000BSE_v6 (RTX PRO 6000 Blackwell Server
Edition) SKUs require the GRID v20 (595.x) driver, published as the
aks-gpu-grid-v20 image. All other GRID SKUs continue to use aks-gpu-grid.

Keys are lower-case VM size names, so callers should lower-case the size
(e.g. with strings.ToLower) before looking it up. A size that is absent from
the map yields false.
*/
//nolint:gochecknoglobals
var RTXPro6000GPUDriverSizes = map[string]bool{
"standard_nc128ds_xl_rtxpro6000bse_v6": true,
"standard_nc256ds_xl_rtxpro6000bse_v6": true,
"standard_nc320ds_xl_rtxpro6000bse_v6": true,
}

//nolint:gochecknoglobals
var FabricManagerGPUSizes = map[string]bool{
// A100
Expand Down
18 changes: 17 additions & 1 deletion pkg/agent/datamodel/gpu_components_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ func TestLoadConfig(t *testing.T) {
t.Error("NvidiaGridDriverVersion is empty")
}

if NvidiaGridV20DriverVersion == "" {
t.Error("NvidiaGridV20DriverVersion is empty")
}

if AKSGPUCudaVersionSuffix == "" {
t.Error("NvidiaCudaDriverVersion is empty")
}

if AKSGPUGridVersionSuffix == "" {
t.Error(("AKSGPUGridVersionSuffix is empty"))
t.Error("AKSGPUGridVersionSuffix is empty")
}

if AKSGPUGridV20VersionSuffix == "" {
t.Error("AKSGPUGridV20VersionSuffix is empty")
}
Comment thread
ganeshkumarashok marked this conversation as resolved.

// Define regular expressions for expected formats
Expand All @@ -40,11 +48,19 @@ func TestLoadConfig(t *testing.T) {
t.Errorf("NvidiaGridDriverVersion '%s' does not match expected format", NvidiaGridDriverVersion)
}

if !versionPattern.MatchString(NvidiaGridV20DriverVersion) {
t.Errorf("NvidiaGridV20DriverVersion '%s' does not match expected format", NvidiaGridV20DriverVersion)
}

if !suffixPattern.MatchString(AKSGPUCudaVersionSuffix) {
t.Errorf("AKSGPUCudaVersionSuffix '%s' does not match expected format", AKSGPUCudaVersionSuffix)
}

if !suffixPattern.MatchString(AKSGPUGridVersionSuffix) {
t.Errorf("AKSGPUGridVersionSuffix '%s' does not match expected format", AKSGPUGridVersionSuffix)
}

if !suffixPattern.MatchString(AKSGPUGridV20VersionSuffix) {
t.Errorf("AKSGPUGridV20VersionSuffix '%s' does not match expected format", AKSGPUGridV20VersionSuffix)
}
}
Loading