Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/csi/blockstorage/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs

_, err = cloud.AttachVolume(ctx, instanceID, volumeID)
if err != nil {
// Returning ResourceExhausted triggers an immediate `NodeGetInfo` RPC call when MutableCSINodeAllocatableCount is enabled
if stackiterrors.IsTooManyDevicesError(err) {
return nil, status.Errorf(codes.ResourceExhausted, "[ControllerPublishVolume] Node can't accept any more volumes %v. All PCIe lanes are exhausted!", err)
}
klog.Errorf("Failed to AttachVolume: %v", err)
return nil, status.Errorf(codes.Internal, "[ControllerPublishVolume] Attach Volume failed with error %v", err)
}
Expand Down
28 changes: 17 additions & 11 deletions pkg/csi/blockstorage/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,9 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest
return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve instance id of node %v", err)
}

flavor, err := ns.Metadata.GetFlavor(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve flavor of node %v", err)
}

maxVolumesPerNode := DetermineMaxVolumesByFlavor(flavor)
// Subtract 1 for root disk and another for configDrive/spare
maxVolumesPerNode -= 2
klog.V(4).Infof("Determined node to support %d volumes", maxVolumesPerNode)

nodeInfo := &csi.NodeGetInfoResponse{
NodeId: nodeID,
MaxVolumesPerNode: maxVolumesPerNode,
MaxVolumesPerNode: ns.calculateMaxVolumesPerNode(),
}

zone, err := ns.Metadata.GetAvailabilityZone(ctx)
Expand All @@ -332,6 +322,22 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest
return nodeInfo, nil
}

// calculateMaxVolumesPerNode returns the volume limit reported by NodeGetInfo:
// the number of free PCIe slots plus the number of CSI volumes counted for
// driverName. A lookup that fails is logged and contributes zero to the sum.
//
// NOTE(review): when both lookups fail the result is 0 — confirm how the CO
// interprets a zero MaxVolumesPerNode.
func (ns *nodeServer) calculateMaxVolumesPerNode() int64 {
	var total int64

	if free, err := mount.CountFreePCIeSlots(); err != nil {
		klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports: %v", err)
	} else {
		total += free
	}

	if staged, err := mount.CountLocalCSIVolumes(driverName); err != nil {
		klog.Errorf("[NodeGetInfo] unable to retrieve volume count: %v", err)
	} else {
		total += staged
	}

	return total
}

func (ns *nodeServer) NodeGetCapabilities(_ context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
klog.V(5).Infof("NodeGetCapabilities called with req: %#v", req)

Expand Down
17 changes: 0 additions & 17 deletions pkg/csi/blockstorage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,6 @@ func ParseEndpoint(ep string) (proto, addr string, err error) {
return "", "", fmt.Errorf("invalid endpoint: %v", ep)
}

// DetermineMaxVolumesByFlavor maps a flavor name to a volume limit.
//
// The limits were specified by the IaaS team and are based on actual tests.
// Rules are checked in order: a name starting with "n" wins over everything
// else, then the part of the name before the first "." is checked for the
// "2a" suffix, and any other name gets the default.
func DetermineMaxVolumesByFlavor(flavor string) int64 {
	const (
		gpuLimit     int64 = 10
		amdGen2Limit int64 = 159
		defaultLimit int64 = 25
	)

	// Flavors starting with 'n' are NVIDIA GPU flavors; all GPU VMs can only
	// mount 10 volumes.
	if strings.HasPrefix(flavor, "n") {
		return gpuLimit
	}

	// Only the leading segment (before the first '.') identifies the AMD 2nd
	// Gen family.
	family, _, _ := strings.Cut(flavor, ".")
	if strings.HasSuffix(family, "2a") {
		return amdGen2Limit
	}

	// Every other flavor gets the default of 25.
	// NOTE(review): an earlier comment here mentioned 28 while the code
	// returns 25 — confirm the intended limit.
	return defaultLimit
}

func logGRPC(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
callID := serverGRPCEndpointCallCounter.Add(1)

Expand Down
25 changes: 0 additions & 25 deletions pkg/csi/blockstorage/utils_test.go

This file was deleted.

9 changes: 9 additions & 0 deletions pkg/csi/util/mount/mount_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats {
UsedInodes: int64(statfs.Files) - int64(statfs.Ffree),
}
}

// CountLocalCSIVolumes is the darwin placeholder: the driver name is ignored
// and the function always reports zero volumes with a nil error.
func CountLocalCSIVolumes(_ string) (int64, error) {
	// not implemented
	var staged int64
	return staged, nil
}

// CountFreePCIeSlots is the darwin placeholder: it always reports zero free
// slots with a nil error.
func CountFreePCIeSlots() (int64, error) {
	var free int64
	return free, nil
}
60 changes: 60 additions & 0 deletions pkg/csi/util/mount/mount_helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package mount

import (
"fmt"
"os"
"path/filepath"
"strings"

"k8s.io/klog/v2"
)

const (
// pciClassBridgePCI matches the Linux PCI-to-PCI bridge class prefix.
// It is compared as a string prefix against the trimmed contents of a
// device's "class" file.
pciClassBridgePCI = "0x0604"
// globalMountDir is the directory name looked for one level below a driver's
// plugin directory (<driverPluginDir>/*/globalmount) when counting volumes.
globalMountDir = "globalmount"
)

// countFreePCIeSlotsAt walks the entries of devicesPath and counts those whose
// "class" file starts with pciClassBridgePCI and that contain no child entry
// shaped like a PCI address (????:??:??.?).
//
// Entries whose class file cannot be read are logged and skipped. An error is
// returned when devicesPath itself cannot be listed or when globbing for
// children fails.
func countFreePCIeSlotsAt(devicesPath string) (int64, error) {
	entries, err := os.ReadDir(devicesPath)
	if err != nil {
		return 0, fmt.Errorf("failed to read PCI bus: %w", err)
	}

	free := int64(0)
	for _, entry := range entries {
		busPath := filepath.Join(devicesPath, entry.Name())

		raw, readErr := os.ReadFile(filepath.Join(busPath, "class"))
		if readErr != nil {
			// Best effort: one unreadable device must not fail the whole scan.
			klog.Errorf("failed to read PCI device class %s: %v", busPath, readErr)
			continue
		}

		// Anything that is not a PCI-to-PCI bridge cannot be a slot.
		if isBridge := strings.HasPrefix(strings.TrimSpace(string(raw)), pciClassBridgePCI); !isBridge {
			continue
		}

		// A bridge with no PCI-address-shaped child is counted as free.
		attached, globErr := filepath.Glob(filepath.Join(busPath, "????:??:??.?"))
		switch {
		case globErr != nil:
			return 0, fmt.Errorf("failed to glob PCI children for %s: %w", busPath, globErr)
		case len(attached) == 0:
			free++
		}
	}

	return free, nil
}

// countLocalCSIVolumesAt counts the paths matching
// <driverPluginDir>/*/<globalMountDir>. A glob failure is wrapped and
// returned; otherwise the number of matches is returned.
func countLocalCSIVolumesAt(driverPluginDir string) (int64, error) {
	pattern := filepath.Join(driverPluginDir, "*", globalMountDir)

	staged, err := filepath.Glob(pattern)
	if err != nil {
		return 0, fmt.Errorf("failed to glob CSI volume mounts in %s: %w", driverPluginDir, err)
	}

	total := int64(len(staged))
	return total, nil
}
162 changes: 162 additions & 0 deletions pkg/csi/util/mount/mount_helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package mount

import (
"os"
"path/filepath"
"testing"
)

// TestCountFreePCIeSlotsAtMissingRoot checks that a devices directory that
// does not exist produces an error.
func TestCountFreePCIeSlotsAtMissingRoot(t *testing.T) {
	t.Parallel()

	missingRoot := filepath.Join(t.TempDir(), "missing")
	if _, err := countFreePCIeSlotsAt(missingRoot); err == nil {
		t.Fatal("countFreePCIeSlotsAt() error = nil, want error")
	}
}

// TestCountFreePCIeSlotsAtCountsOnlyFreeBridgeSlots builds three devices and
// expects exactly one to be counted as a free slot.
func TestCountFreePCIeSlotsAtCountsOnlyFreeBridgeSlots(t *testing.T) {
	t.Parallel()

	root := t.TempDir()

	// Bridge without children: the only entry expected to be counted.
	createPCIDevice(t, root, "0000:00:00.0", "0x060400")
	// Bridge that already has a child device.
	createPCIDevice(t, root, "0000:00:01.0", "0x060400", "0000:01:00.0")
	// Device whose class does not carry the bridge prefix.
	createPCIDevice(t, root, "0000:00:02.0", "0x010000", "0000:02:00.0")

	got, err := countFreePCIeSlotsAt(root)
	switch {
	case err != nil:
		t.Fatalf("countFreePCIeSlotsAt() error = %v", err)
	case got != 1:
		t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", got)
	}
}

// TestCountFreePCIeSlotsAtSkipsDevicesWithoutClass checks that a device
// directory lacking a "class" file is skipped rather than failing the count.
func TestCountFreePCIeSlotsAtSkipsDevicesWithoutClass(t *testing.T) {
	t.Parallel()

	root := t.TempDir()

	// One free bridge, plus a bare directory with no class file.
	createPCIDevice(t, root, "0000:00:00.0", "0x060400")
	mustMkdirAll(t, filepath.Join(root, "0000:00:01.0"))

	if free, err := countFreePCIeSlotsAt(root); err != nil {
		t.Fatalf("countFreePCIeSlotsAt() error = %v", err)
	} else if free != 1 {
		t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", free)
	}
}

// TestCountFreePCIeSlotsAtIgnoresNonPCIChildren checks that subdirectories of
// a bridge whose names are not PCI-address shaped do not make it occupied.
func TestCountFreePCIeSlotsAtIgnoresNonPCIChildren(t *testing.T) {
	t.Parallel()

	root := t.TempDir()
	bridgeDir := filepath.Join(root, "0000:00:00.0")

	mustMkdirAll(t, bridgeDir)
	mustWriteFile(t, filepath.Join(bridgeDir, "class"), "0x060400")

	// Neither name matches the ????:??:??.? child pattern.
	for _, name := range []string{"driver", "not-a-pci-child"} {
		mustMkdirAll(t, filepath.Join(bridgeDir, name))
	}

	free, err := countFreePCIeSlotsAt(root)
	switch {
	case err != nil:
		t.Fatalf("countFreePCIeSlotsAt() error = %v", err)
	case free != 1:
		t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", free)
	}
}

// TestCountLocalCSIVolumesAtMissingDir checks that a plugin directory that
// does not exist yields a count of zero and no error.
func TestCountLocalCSIVolumesAtMissingDir(t *testing.T) {
	t.Parallel()

	missingDir := filepath.Join(t.TempDir(), "missing")

	if got, err := countLocalCSIVolumesAt(missingDir); err != nil {
		t.Fatalf("countLocalCSIVolumesAt() error = %v", err)
	} else if got != 0 {
		t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", got)
	}
}

func TestCountLocalCSIVolumesAtCountsOnlyGlobalMountDirs(t *testing.T) {
t.Parallel()

driverPluginDir := t.TempDir()

mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-a", globalMountDir))
mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-b", globalMountDir))
mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-c", "not-a-globalmount"))

count, err := countLocalCSIVolumesAt(driverPluginDir)
if err != nil {
t.Fatalf("countLocalCSIVolumesAt() error = %v", err)
}

if count != 2 {
t.Fatalf("countLocalCSIVolumesAt() = %d, want 2", count)
}
}

// TestCountLocalCSIVolumesAtEmptyDir checks that an existing but empty plugin
// directory yields a count of zero and no error.
func TestCountLocalCSIVolumesAtEmptyDir(t *testing.T) {
	t.Parallel()

	emptyDir := t.TempDir()

	got, err := countLocalCSIVolumesAt(emptyDir)
	switch {
	case err != nil:
		t.Fatalf("countLocalCSIVolumesAt() error = %v", err)
	case got != 0:
		t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", got)
	}
}

// TestCountLocalCSIVolumesAtReturnsZeroWhenDriverPathIsFile checks that
// pointing the counter at a regular file yields zero and no error.
func TestCountLocalCSIVolumesAtReturnsZeroWhenDriverPathIsFile(t *testing.T) {
	t.Parallel()

	// Create a plain file where a plugin directory would normally be.
	filePath := filepath.Join(t.TempDir(), "driver")
	mustWriteFile(t, filePath, "not a directory")

	if got, err := countLocalCSIVolumesAt(filePath); err != nil {
		t.Fatalf("countLocalCSIVolumesAt() error = %v", err)
	} else if got != 0 {
		t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", got)
	}
}

// createPCIDevice creates the directory rootPath/deviceName, writes class
// into its "class" file, and adds one subdirectory per entry in children.
// Any failure aborts the test through the must* helpers.
func createPCIDevice(t *testing.T, rootPath, deviceName, class string, children ...string) {
	t.Helper()

	deviceDir := filepath.Join(rootPath, deviceName)
	mustMkdirAll(t, deviceDir)

	classFile := filepath.Join(deviceDir, "class")
	mustWriteFile(t, classFile, class)

	for i := range children {
		mustMkdirAll(t, filepath.Join(deviceDir, children[i]))
	}
}

// mustMkdirAll creates path and any missing parents with mode 0o755, failing
// the test immediately if the directory cannot be created.
func mustMkdirAll(t *testing.T, path string) {
	t.Helper()

	err := os.MkdirAll(path, 0o755)
	if err == nil {
		return
	}

	t.Fatalf("MkdirAll(%q) error = %v", path, err)
}

// mustWriteFile writes content to path with mode 0o644, failing the test
// immediately if the write does not succeed.
func mustWriteFile(t *testing.T, path string, content string) {
	t.Helper()

	data := []byte(content)
	err := os.WriteFile(path, data, 0o644)
	if err == nil {
		return
	}

	t.Fatalf("WriteFile(%q) error = %v", path, err)
}
22 changes: 21 additions & 1 deletion pkg/csi/util/mount/mount_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

package mount

import "golang.org/x/sys/unix"
import (
"path/filepath"

"golang.org/x/sys/unix"
)

const (
// pciDevicesPath is the directory CountFreePCIeSlots scans for PCI devices.
pciDevicesPath = "/sys/bus/pci/devices"
// kubeletDir is the root under which CountLocalCSIVolumes builds the
// per-driver plugin directory (plugins/kubernetes.io/csi/<driverName>).
kubeletDir = "/var/lib/kubelet"
)

func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats {
return &DeviceStats{
Expand All @@ -17,3 +26,14 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats {
UsedInodes: int64(statfs.Files) - int64(statfs.Ffree),
}
}

// CountFreePCIeSlots scans pciDevicesPath and reports how many PCI bridge
// devices found there have no child device attached. Any error from the scan
// is passed through unchanged.
func CountFreePCIeSlots() (int64, error) {
	free, err := countFreePCIeSlotsAt(pciDevicesPath)
	return free, err
}

// CountLocalCSIVolumes counts the globalmount directories found under the
// given driver's plugin directory below kubeletDir.
func CountLocalCSIVolumes(driverName string) (int64, error) {
	// <kubeletDir>/plugins/kubernetes.io/csi/<driverName>
	csiPluginsRoot := filepath.Join(kubeletDir, "plugins", "kubernetes.io", "csi")
	return countLocalCSIVolumesAt(filepath.Join(csiPluginsRoot, driverName))
}
Loading