@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
1414limitations under the License.
1515*/
1616
17- package controllers
17+ package nvidiadriver
1818
1919import (
2020 "context"
@@ -29,25 +29,18 @@ import (
2929 "github.com/NVIDIA/gpu-operator/internal/consts"
3030)
3131
32- // isDefaultNVIDIADriver returns true when the NVIDIADriver is marked as the fallback driver.
33- func isDefaultNVIDIADriver (driver * nvidiav1alpha1.NVIDIADriver ) bool {
32+ // IsDefault returns true when the NVIDIADriver is marked as the fallback driver.
33+ func IsDefault (driver * nvidiav1alpha1.NVIDIADriver ) bool {
3434 return driver != nil && driver .Spec .Default
3535}
3636
37- // nvidiaDriverCRDEnabled returns true when ClusterPolicy driver management is enabled through NVIDIADriver CRs.
38- func nvidiaDriverCRDEnabled (clusterPolicy * gpuv1.ClusterPolicy ) bool {
39- return clusterPolicy != nil &&
40- clusterPolicy .Spec .Driver .IsEnabled () &&
41- clusterPolicy .Spec .Driver .UseNvidiaDriverCRDType ()
42- }
43-
44- // validateNVIDIADriverNodeSelector rejects selectors that use operator-managed routing labels
37+ // ValidateNodeSelector rejects selectors that use operator-managed routing labels
4538// or scope the default fallback driver.
46- func validateNVIDIADriverNodeSelector (driver * nvidiav1alpha1.NVIDIADriver ) error {
39+ func ValidateNodeSelector (driver * nvidiav1alpha1.NVIDIADriver ) error {
4740 if driver == nil || driver .Spec .NodeSelector == nil {
4841 return nil
4942 }
50- if isDefaultNVIDIADriver (driver ) && len (driver .Spec .NodeSelector ) > 0 {
43+ if IsDefault (driver ) && len (driver .Spec .NodeSelector ) > 0 {
5144 return fmt .Errorf ("default NVIDIADriver %q cannot use nodeSelector" , driver .Name )
5245 }
5346 if _ , ok := driver .Spec .NodeSelector [consts .NVIDIADriverOwnerLabel ]; ok {
@@ -56,59 +49,77 @@ func validateNVIDIADriverNodeSelector(driver *nvidiav1alpha1.NVIDIADriver) error
5649 return nil
5750}
5851
59- // assignNVIDIADriverOwners labels GPU nodes with the NVIDIADriver that should manage their driver pods.
52+ // CRDEnabled returns true when ClusterPolicy driver management is enabled through NVIDIADriver CRs.
53+ func CRDEnabled (clusterPolicy * gpuv1.ClusterPolicy ) bool {
54+ return clusterPolicy != nil &&
55+ clusterPolicy .Spec .Driver .IsEnabled () &&
56+ clusterPolicy .Spec .Driver .UseNvidiaDriverCRDType ()
57+ }
58+
59+ // NodeMatchesSelector returns true when all selector labels are present on the node.
60+ func NodeMatchesSelector (node * corev1.Node , selector map [string ]string ) bool {
61+ for key , value := range selector {
62+ if node .Labels [key ] != value {
63+ return false
64+ }
65+ }
66+ return true
67+ }
68+
69+ // AssignOwners labels GPU nodes with the NVIDIADriver that should manage their driver pods.
6070// Non-default NVIDIADrivers take precedence over the default fallback, and conflicts fail closed before
61- // node owner labels are changed.
62- func assignNVIDIADriverOwners (ctx context.Context , c client.Client ) error {
71+ // node owner labels are changed. It returns true when any node label was changed.
72+ func AssignOwners (ctx context.Context , c client.Client ) ( bool , error ) {
6373 drivers := & nvidiav1alpha1.NVIDIADriverList {}
6474 if err := c .List (ctx , drivers ); err != nil {
65- return fmt .Errorf ("failed to list NVIDIADriver CRs: %w" , err )
75+ return false , fmt .Errorf ("failed to list NVIDIADriver CRs: %w" , err )
6676 }
6777
6878 var defaultDriver * nvidiav1alpha1.NVIDIADriver
6979 defaultDriverNames := []string {}
7080 specificDrivers := make ([]nvidiav1alpha1.NVIDIADriver , 0 , len (drivers .Items ))
7181 for i := range drivers .Items {
72- if err := validateNVIDIADriverNodeSelector (& drivers .Items [i ]); err != nil {
73- return err
82+ if err := ValidateNodeSelector (& drivers .Items [i ]); err != nil {
83+ return false , err
7484 }
75- if isDefaultNVIDIADriver (& drivers .Items [i ]) {
85+ if IsDefault (& drivers .Items [i ]) {
7686 defaultDriverNames = append (defaultDriverNames , drivers .Items [i ].Name )
7787 defaultDriver = & drivers .Items [i ]
7888 continue
7989 }
8090 specificDrivers = append (specificDrivers , drivers .Items [i ])
8191 }
8292 if len (defaultDriverNames ) > 1 {
83- return fmt .Errorf ("multiple default NVIDIADrivers found: %s" , strings .Join (defaultDriverNames , ", " ))
93+ return false , fmt .Errorf ("multiple default NVIDIADrivers found: %s" , strings .Join (defaultDriverNames , ", " ))
8494 }
8595 nodes := & corev1.NodeList {}
8696 if err := c .List (ctx , nodes , client.MatchingLabels {consts .GPUPresentLabel : "true" }); err != nil {
87- return fmt .Errorf ("failed to list GPU nodes: %w" , err )
97+ return false , fmt .Errorf ("failed to list GPU nodes: %w" , err )
8898 }
8999
90100 for i := range nodes .Items {
91101 matchingDrivers := []string {}
92102 for _ , driver := range specificDrivers {
93- if nodeMatchesSelector (& nodes .Items [i ], driver .GetNodeSelector ()) {
103+ if NodeMatchesSelector (& nodes .Items [i ], driver .GetNodeSelector ()) {
94104 matchingDrivers = append (matchingDrivers , driver .Name )
95105 }
96106 }
97107 if len (matchingDrivers ) > 1 {
98- return fmt .Errorf ("conflicting NVIDIADriver NodeSelectors found for node %s: %s " , nodes .Items [i ].Name , strings . Join ( matchingDrivers , ", " ) )
108+ return false , fmt .Errorf ("multiple NVIDIADrivers match the same node %s: %v " , nodes .Items [i ].Name , matchingDrivers )
99109 }
100110 }
101111
112+ changed := false
102113 for i := range nodes .Items {
103114 node := & nodes .Items [i ]
104115 originalNode := node .DeepCopy ()
105116 owner := ""
106117 for _ , driver := range specificDrivers {
107- if nodeMatchesSelector (node , driver .GetNodeSelector ()) {
118+ if NodeMatchesSelector (node , driver .GetNodeSelector ()) {
108119 owner = driver .Name
109120 }
110121 }
111- if owner == "" && defaultDriver != nil && nodeMatchesSelector (node , defaultDriver .GetNodeSelector ()) {
122+ if owner == "" && defaultDriver != nil && NodeMatchesSelector (node , defaultDriver .GetNodeSelector ()) {
112123 owner = defaultDriver .Name
113124 }
114125 if owner == "" {
@@ -120,8 +131,9 @@ func assignNVIDIADriverOwners(ctx context.Context, c client.Client) error {
120131 }
121132 delete (node .Labels , consts .NVIDIADriverOwnerLabel )
122133 if err := c .Patch (ctx , node , client .MergeFrom (originalNode )); err != nil {
123- return fmt .Errorf ("failed to remove NVIDIADriver owner label for node %s: %w" , node .Name , err )
134+ return false , fmt .Errorf ("failed to remove NVIDIADriver owner label for node %s: %w" , node .Name , err )
124135 }
136+ changed = true
125137 continue
126138 }
127139 if node .Labels != nil && node .Labels [consts .NVIDIADriverOwnerLabel ] == owner {
@@ -132,9 +144,10 @@ func assignNVIDIADriverOwners(ctx context.Context, c client.Client) error {
132144 }
133145 node .Labels [consts .NVIDIADriverOwnerLabel ] = owner
134146 if err := c .Patch (ctx , node , client .MergeFrom (originalNode )); err != nil {
135- return fmt .Errorf ("failed to update NVIDIADriver owner label for node %s: %w" , node .Name , err )
147+ return false , fmt .Errorf ("failed to update NVIDIADriver owner label for node %s: %w" , node .Name , err )
136148 }
149+ changed = true
137150 }
138151
139- return nil
152+ return changed , nil
140153}
0 commit comments