diff --git a/config/webhook/manifests.yaml b/config/webhook/manifests.yaml index f534bd66b..8d6ea21fa 100644 --- a/config/webhook/manifests.yaml +++ b/config/webhook/manifests.yaml @@ -28,4 +28,6 @@ webhooks: - clustermanagers - licensemanagers - monitoringconsoles + - postgresclusters + - postgresclusterclasses sideEffects: None diff --git a/pkg/postgresql/cluster/adapter/webhook/postgres_webhook_integration_test.go b/pkg/postgresql/cluster/adapter/webhook/postgres_webhook_integration_test.go new file mode 100644 index 000000000..c8f2eaf3d --- /dev/null +++ b/pkg/postgresql/cluster/adapter/webhook/postgres_webhook_integration_test.go @@ -0,0 +1,573 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + admissionv1 "k8s.io/api/admission/v1" + authenticationv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + + enterpriseApi "github.com/splunk/splunk-operator/api/v4" + "github.com/splunk/splunk-operator/pkg/splunk/enterprise/validation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustMarshal(t *testing.T, obj interface{}) []byte { + t.Helper() + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("failed to marshal object: %v", err) + } + return data +} + +func newPostgresClusterAdmissionReview(t *testing.T, uid string, op admissionv1.Operation, obj *enterpriseApi.PostgresCluster, oldObj *enterpriseApi.PostgresCluster) *admissionv1.AdmissionReview { + t.Helper() + ar := &admissionv1.AdmissionReview{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "admission.k8s.io/v1", + Kind: "AdmissionReview", + }, + Request: &admissionv1.AdmissionRequest{ + UID: types.UID(uid), + Kind: metav1.GroupVersionKind{ + Group: "enterprise.splunk.com", + Version: "v4", + Kind: "PostgresCluster", + }, + Resource: metav1.GroupVersionResource{ + Group: "enterprise.splunk.com", + Version: "v4", + Resource: "postgresclusters", + }, + Name: obj.Name, + Namespace: obj.Namespace, + Operation: op, + Object: runtime.RawExtension{ + Raw: mustMarshal(t, obj), + }, + UserInfo: authenticationv1.UserInfo{Username: "test-user"}, + }, + } + if oldObj != nil { + ar.Request.OldObject = runtime.RawExtension{ + Raw: mustMarshal(t, oldObj), + } + } + return ar +} + +func newPostgresClusterClassAdmissionReview(t *testing.T, uid string, op admissionv1.Operation, obj *enterpriseApi.PostgresClusterClass, oldObj *enterpriseApi.PostgresClusterClass) *admissionv1.AdmissionReview { + t.Helper() + ar := &admissionv1.AdmissionReview{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "admission.k8s.io/v1", + Kind: "AdmissionReview", + }, + Request: &admissionv1.AdmissionRequest{ + UID: types.UID(uid), + Kind: metav1.GroupVersionKind{ + Group: "enterprise.splunk.com", + Version: "v4", + Kind: "PostgresClusterClass", + }, + Resource: metav1.GroupVersionResource{ + Group: "enterprise.splunk.com", + Version: "v4", + Resource: "postgresclusterclasses", + }, + Name: obj.Name, + Operation: op, + Object: runtime.RawExtension{ + Raw: mustMarshal(t, obj), + }, + UserInfo: authenticationv1.UserInfo{Username: "test-user"}, + }, + } + if oldObj != nil { + ar.Request.OldObject = runtime.RawExtension{ + Raw: mustMarshal(t, oldObj), + } + } + return ar +} + +func sendAdmissionReview(t *testing.T, server *validation.WebhookServer, ar *admissionv1.AdmissionReview) *admissionv1.AdmissionResponse { + t.Helper() + body, err := json.Marshal(ar) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/validate", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + server.HandleValidate(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + var response admissionv1.AdmissionReview + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + require.NotNil(t, response.Response) + return response.Response +} + +func TestPostgresClusterPgHBAIntegration(t *testing.T) { + server := validation.NewWebhookServer(validation.WebhookServerOptions{ + Port: 9443, + Validators: validation.DefaultValidators, + }) + + tests := []struct { + name string + obj *enterpriseApi.PostgresCluster + wantAllowed bool + wantMessage string + wantMessages []string + }{ + { + name: "valid - no pgHBA rules", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + }, + }, + wantAllowed: true, + }, + { + name: "valid - correct pgHBA rules", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "hostnossl all all 0.0.0.0/0 reject", + "hostssl all all 0.0.0.0/0 scram-sha-256", + "local all all peer", + }, + }, + }, + wantAllowed: true, + }, + { + name: "rejected - bad connection type", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "hostx all all 0.0.0.0/0 md5", + }, + }, + }, + wantAllowed: false, + wantMessage: "unknown connection type", + }, + { + name: "rejected - bad CIDR", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 192.168.0.0/33 md5", + }, + }, + }, + wantAllowed: false, + wantMessage: "invalid CIDR", + }, + { + name: "rejected - unknown auth method", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 0.0.0.0/0 bogus", + }, + }, + }, + wantAllowed: false, + wantMessage: "unknown auth method", + }, + { + name: "rejected - too few fields", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all", + }, + }, + }, + wantAllowed: false, + wantMessage: "too few fields", + }, + { + name: "rejected - multiple bad rules reports all errors", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "hostssl all all 0.0.0.0/0 scram-sha-256", + "hostx all all 0.0.0.0/0 md5", + "host all all 10.0.0.0/8 bogus", + }, + }, + }, + wantAllowed: false, + wantMessages: []string{"spec.pgHBA[1]", "spec.pgHBA[2]"}, + }, + { + name: "valid - rules with auth options and comments", + obj: &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 0.0.0.0/0 ldap ldapserver=ldap.example.com ldapport=389", + "host all all 0.0.0.0/0 md5 # office access", + }, + }, + }, + wantAllowed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ar := newPostgresClusterAdmissionReview(t, "uid-"+tt.name, admissionv1.Create, tt.obj, nil) + resp := sendAdmissionReview(t, server, ar) + + assert.Equal(t, tt.wantAllowed, resp.Allowed, "unexpected admission result") + if !tt.wantAllowed { + require.NotNil(t, resp.Result) + assert.Equal(t, metav1.StatusReasonInvalid, resp.Result.Reason) + assert.Equal(t, int32(http.StatusUnprocessableEntity), resp.Result.Code) + } + if tt.wantMessage != "" { + require.NotNil(t, resp.Result) + assert.Contains(t, resp.Result.Message, tt.wantMessage) + } + for _, msg := range tt.wantMessages { + require.NotNil(t, resp.Result) + assert.Contains(t, resp.Result.Message, msg) + } + }) + } +} + +func TestPostgresClusterPgHBAUpdateIntegration(t *testing.T) { + server := validation.NewWebhookServer(validation.WebhookServerOptions{ + Port: 9443, + Validators: validation.DefaultValidators, + }) + + oldObj := &enterpriseApi.PostgresCluster{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresCluster", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 0.0.0.0/0 scram-sha-256", + }, + }, + } + + t.Run("valid update - change rules", func(t *testing.T) { + newObj := oldObj.DeepCopy() + newObj.Spec.PgHBA = []string{ + "hostssl all all 0.0.0.0/0 scram-sha-256", + "local all all peer", + } + ar := newPostgresClusterAdmissionReview(t, "uid-update-valid", admissionv1.Update, newObj, oldObj) + resp := sendAdmissionReview(t, server, ar) + assert.True(t, resp.Allowed) + }) + + t.Run("rejected update - invalid new rules", func(t *testing.T) { + newObj := oldObj.DeepCopy() + newObj.Spec.PgHBA = []string{ + "hostx all all 0.0.0.0/0 md5", + } + ar := newPostgresClusterAdmissionReview(t, "uid-update-invalid", admissionv1.Update, newObj, oldObj) + resp := sendAdmissionReview(t, server, ar) + assert.False(t, resp.Allowed) + assert.Equal(t, metav1.StatusReasonInvalid, resp.Result.Reason) + assert.Equal(t, int32(http.StatusUnprocessableEntity), resp.Result.Code) + assert.Contains(t, resp.Result.Message, "unknown connection type") + }) +} + +func TestPostgresClusterClassPgHBAIntegration(t *testing.T) { + server := validation.NewWebhookServer(validation.WebhookServerOptions{ + Port: 9443, + Validators: validation.DefaultValidators, + }) + + tests := []struct { + name string + obj *enterpriseApi.PostgresClusterClass + wantAllowed bool + wantMessage string + }{ + { + name: "valid - no pgHBA rules", + obj: &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + }, + }, + wantAllowed: true, + }, + { + name: "valid - correct pgHBA rules", + obj: &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "hostnossl all all 0.0.0.0/0 reject", + "hostssl all all 0.0.0.0/0 scram-sha-256", + }, + }, + }, + }, + wantAllowed: true, + }, + { + name: "rejected - bad connection type", + obj: &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "hostx all all 0.0.0.0/0 md5", + }, + }, + }, + }, + wantAllowed: false, + wantMessage: "unknown connection type", + }, + { + name: "rejected - invalid CIDR in class", + obj: &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "host all all 256.1.1.1/24 md5", + }, + }, + }, + }, + wantAllowed: false, + wantMessage: "invalid CIDR", + }, + { + name: "rejected - unknown auth method in class", + obj: &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "host all all 0.0.0.0/0 fake-method", + }, + }, + }, + }, + wantAllowed: false, + wantMessage: "unknown auth method", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ar := newPostgresClusterClassAdmissionReview(t, "uid-"+tt.name, admissionv1.Create, tt.obj, nil) + resp := sendAdmissionReview(t, server, ar) + + assert.Equal(t, tt.wantAllowed, resp.Allowed, "unexpected admission result") + if !tt.wantAllowed { + require.NotNil(t, resp.Result) + assert.Equal(t, metav1.StatusReasonInvalid, resp.Result.Reason) + assert.Equal(t, int32(http.StatusUnprocessableEntity), resp.Result.Code) + } + if tt.wantMessage != "" { + require.NotNil(t, resp.Result) + assert.Contains(t, resp.Result.Message, tt.wantMessage) + } + }) + } +} + +func TestPostgresClusterClassPgHBAUpdateIntegration(t *testing.T) { + server := validation.NewWebhookServer(validation.WebhookServerOptions{ + Port: 9443, + Validators: validation.DefaultValidators, + }) + + oldObj := &enterpriseApi.PostgresClusterClass{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "enterprise.splunk.com/v4", + Kind: "PostgresClusterClass", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "host all all 0.0.0.0/0 scram-sha-256", + }, + }, + }, + } + + t.Run("valid update - change rules", func(t *testing.T) { + newObj := oldObj.DeepCopy() + newObj.Spec.Config.PgHBA = []string{ + "hostssl all all 0.0.0.0/0 scram-sha-256", + "hostnossl all all 0.0.0.0/0 reject", + } + ar := newPostgresClusterClassAdmissionReview(t, "uid-class-update-valid", admissionv1.Update, newObj, oldObj) + resp := sendAdmissionReview(t, server, ar) + assert.True(t, resp.Allowed) + }) + + t.Run("rejected update - invalid new rules", func(t *testing.T) { + newObj := oldObj.DeepCopy() + newObj.Spec.Config.PgHBA = []string{ + "host all all 0.0.0.0/0 bogus", + } + ar := newPostgresClusterClassAdmissionReview(t, "uid-class-update-invalid", admissionv1.Update, newObj, oldObj) + resp := sendAdmissionReview(t, server, ar) + assert.False(t, resp.Allowed) + assert.Equal(t, metav1.StatusReasonInvalid, resp.Result.Reason) + assert.Equal(t, int32(http.StatusUnprocessableEntity), resp.Result.Code) + assert.Contains(t, resp.Result.Message, "unknown auth method") + }) +} diff --git a/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation.go b/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation.go new file mode 100644 index 000000000..7fc724c4f --- /dev/null +++ b/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation.go @@ -0,0 +1,56 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook + +import ( + "k8s.io/apimachinery/pkg/util/validation/field" + + enterpriseApi "github.com/splunk/splunk-operator/api/v4" + hba "github.com/splunk/splunk-operator/pkg/postgresql/cluster/core" +) + +// ValidatePostgresClusterCreate validates a PostgresCluster on CREATE. +func ValidatePostgresClusterCreate(obj *enterpriseApi.PostgresCluster) field.ErrorList { + var allErrs field.ErrorList + + if len(obj.Spec.PgHBA) > 0 { + pgHBAPath := field.NewPath("spec").Child("pgHBA") + for _, re := range hba.ValidateRules(obj.Spec.PgHBA) { + allErrs = append(allErrs, field.Invalid( + pgHBAPath.Index(re.Index), + obj.Spec.PgHBA[re.Index], + re.Message)) + } + } + + return allErrs +} + +// ValidatePostgresClusterUpdate validates a PostgresCluster on UPDATE. +func ValidatePostgresClusterUpdate(obj, oldObj *enterpriseApi.PostgresCluster) field.ErrorList { + return ValidatePostgresClusterCreate(obj) +} + +// GetPostgresClusterWarningsOnCreate returns warnings for PostgresCluster CREATE. +func GetPostgresClusterWarningsOnCreate(obj *enterpriseApi.PostgresCluster) []string { + return nil +} + +// GetPostgresClusterWarningsOnUpdate returns warnings for PostgresCluster UPDATE. +func GetPostgresClusterWarningsOnUpdate(obj, oldObj *enterpriseApi.PostgresCluster) []string { + return nil +} diff --git a/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation_test.go b/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation_test.go new file mode 100644 index 000000000..56ff34c9c --- /dev/null +++ b/pkg/postgresql/cluster/adapter/webhook/postgrescluster_validation_test.go @@ -0,0 +1,192 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook + +import ( + "testing" + + enterpriseApi "github.com/splunk/splunk-operator/api/v4" + "github.com/stretchr/testify/assert" +) + +func TestValidatePostgresClusterCreate(t *testing.T) { + tests := []struct { + name string + obj *enterpriseApi.PostgresCluster + wantErrCount int + wantErrField string + }{ + { + name: "valid - no pgHBA rules", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + }, + }, + wantErrCount: 0, + }, + { + name: "valid - empty pgHBA", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{}, + }, + }, + wantErrCount: 0, + }, + { + name: "valid - correct pgHBA rules", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "hostnossl all all 0.0.0.0/0 reject", + "hostssl all all 0.0.0.0/0 scram-sha-256", + }, + }, + }, + wantErrCount: 0, + }, + { + name: "invalid - bad connection type", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "hostx all all 0.0.0.0/0 md5", + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.pgHBA[0]", + }, + { + name: "invalid - bad CIDR", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 192.168.0.0/33 md5", + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.pgHBA[0]", + }, + { + name: "invalid - bad auth method", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all 0.0.0.0/0 bogus-auth", + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.pgHBA[0]", + }, + { + name: "invalid - missing fields", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{ + "host all all", + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.pgHBA[0]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := ValidatePostgresClusterCreate(tt.obj) + assert.Len(t, errs, tt.wantErrCount, "unexpected error count") + if tt.wantErrField != "" && len(errs) > 0 { + assert.Equal(t, tt.wantErrField, errs[0].Field, "unexpected error field") + } + }) + } +} + +func TestValidatePostgresClusterUpdate(t *testing.T) { + tests := []struct { + name string + obj *enterpriseApi.PostgresCluster + oldObj *enterpriseApi.PostgresCluster + wantErrCount int + }{ + { + name: "valid update - add pgHBA rules", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{"host all all 0.0.0.0/0 scram-sha-256"}, + }, + }, + oldObj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + }, + }, + wantErrCount: 0, + }, + { + name: "invalid update - bad pgHBA", + obj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + PgHBA: []string{"hostx all all 0.0.0.0/0 md5"}, + }, + }, + oldObj: &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{ + Class: "dev", + }, + }, + wantErrCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := ValidatePostgresClusterUpdate(tt.obj, tt.oldObj) + assert.Len(t, errs, tt.wantErrCount, "unexpected error count") + }) + } +} + +func TestGetPostgresClusterWarningsOnCreate(t *testing.T) { + obj := &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{Class: "dev"}, + } + assert.Empty(t, GetPostgresClusterWarningsOnCreate(obj)) +} + +func TestGetPostgresClusterWarningsOnUpdate(t *testing.T) { + obj := &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{Class: "dev"}, + } + oldObj := &enterpriseApi.PostgresCluster{ + Spec: enterpriseApi.PostgresClusterSpec{Class: "dev"}, + } + assert.Empty(t, GetPostgresClusterWarningsOnUpdate(obj, oldObj)) +} diff --git a/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation.go b/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation.go new file mode 100644 index 000000000..28246cba4 --- /dev/null +++ b/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation.go @@ -0,0 +1,56 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook + +import ( + "k8s.io/apimachinery/pkg/util/validation/field" + + enterpriseApi "github.com/splunk/splunk-operator/api/v4" + hba "github.com/splunk/splunk-operator/pkg/postgresql/cluster/core" +) + +// ValidatePostgresClusterClassCreate validates a PostgresClusterClass on CREATE. +func ValidatePostgresClusterClassCreate(obj *enterpriseApi.PostgresClusterClass) field.ErrorList { + var allErrs field.ErrorList + + if obj.Spec.Config != nil && len(obj.Spec.Config.PgHBA) > 0 { + pgHBAPath := field.NewPath("spec").Child("config").Child("pgHBA") + for _, re := range hba.ValidateRules(obj.Spec.Config.PgHBA) { + allErrs = append(allErrs, field.Invalid( + pgHBAPath.Index(re.Index), + obj.Spec.Config.PgHBA[re.Index], + re.Message)) + } + } + + return allErrs +} + +// ValidatePostgresClusterClassUpdate validates a PostgresClusterClass on UPDATE. +func ValidatePostgresClusterClassUpdate(obj, oldObj *enterpriseApi.PostgresClusterClass) field.ErrorList { + return ValidatePostgresClusterClassCreate(obj) +} + +// GetPostgresClusterClassWarningsOnCreate returns warnings for PostgresClusterClass CREATE. +func GetPostgresClusterClassWarningsOnCreate(obj *enterpriseApi.PostgresClusterClass) []string { + return nil +} + +// GetPostgresClusterClassWarningsOnUpdate returns warnings for PostgresClusterClass UPDATE. +func GetPostgresClusterClassWarningsOnUpdate(obj, oldObj *enterpriseApi.PostgresClusterClass) []string { + return nil +} diff --git a/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation_test.go b/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation_test.go new file mode 100644 index 000000000..5f0bef95c --- /dev/null +++ b/pkg/postgresql/cluster/adapter/webhook/postgresclusterclass_validation_test.go @@ -0,0 +1,191 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhook + +import ( + "testing" + + enterpriseApi "github.com/splunk/splunk-operator/api/v4" + "github.com/stretchr/testify/assert" +) + +func TestValidatePostgresClusterClassCreate(t *testing.T) { + tests := []struct { + name string + obj *enterpriseApi.PostgresClusterClass + wantErrCount int + wantErrField string + }{ + { + name: "valid - no config", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + }, + }, + wantErrCount: 0, + }, + { + name: "valid - config without pgHBA", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{}, + }, + }, + wantErrCount: 0, + }, + { + name: "valid - correct pgHBA rules", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "hostnossl all all 0.0.0.0/0 reject", + "hostssl all all 0.0.0.0/0 scram-sha-256", + }, + }, + }, + }, + wantErrCount: 0, + }, + { + name: "invalid - bad connection type", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "hostx all all 0.0.0.0/0 md5", + }, + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.config.pgHBA[0]", + }, + { + name: "invalid - bad CIDR in class", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "host all all 256.1.1.1/24 md5", + }, + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.config.pgHBA[0]", + }, + { + name: "invalid - unknown auth method in class", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{ + "host all all 0.0.0.0/0 bogus", + }, + }, + }, + }, + wantErrCount: 1, + wantErrField: "spec.config.pgHBA[0]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := ValidatePostgresClusterClassCreate(tt.obj) + assert.Len(t, errs, tt.wantErrCount, "unexpected error count") + if tt.wantErrField != "" && len(errs) > 0 { + assert.Equal(t, tt.wantErrField, errs[0].Field, "unexpected error field") + } + }) + } +} + +func TestValidatePostgresClusterClassUpdate(t *testing.T) { + tests := []struct { + name string + obj *enterpriseApi.PostgresClusterClass + oldObj *enterpriseApi.PostgresClusterClass + wantErrCount int + }{ + { + name: "valid update", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{"host all all 0.0.0.0/0 scram-sha-256"}, + }, + }, + }, + oldObj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + }, + }, + wantErrCount: 0, + }, + { + name: "invalid update - bad pgHBA", + obj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + Config: &enterpriseApi.PostgresClusterClassConfig{ + PgHBA: []string{"host all all 0.0.0.0/0 fake-method"}, + }, + }, + }, + oldObj: &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{ + Provisioner: "postgresql.cnpg.io", + }, + }, + wantErrCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := ValidatePostgresClusterClassUpdate(tt.obj, tt.oldObj) + assert.Len(t, errs, tt.wantErrCount, "unexpected error count") + }) + } +} + +func TestGetPostgresClusterClassWarningsOnCreate(t *testing.T) { + obj := &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{Provisioner: "postgresql.cnpg.io"}, + } + assert.Empty(t, GetPostgresClusterClassWarningsOnCreate(obj)) +} + +func TestGetPostgresClusterClassWarningsOnUpdate(t *testing.T) { + obj := &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{Provisioner: "postgresql.cnpg.io"}, + } + oldObj := &enterpriseApi.PostgresClusterClass{ + Spec: enterpriseApi.PostgresClusterClassSpec{Provisioner: "postgresql.cnpg.io"}, + } + assert.Empty(t, GetPostgresClusterClassWarningsOnUpdate(obj, oldObj)) +} diff --git a/pkg/postgresql/cluster/core/hba.go b/pkg/postgresql/cluster/core/hba.go new file mode 100644 index 000000000..099597696 --- /dev/null +++ b/pkg/postgresql/cluster/core/hba.go @@ -0,0 +1,253 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "fmt" + "net" + "regexp" + "strings" +) + +var hbaConnectionTypes = map[string]bool{ + "local": true, + "host": true, + "hostssl": true, + "hostnossl": true, + "hostgssenc": true, + "hostnogssenc": true, +} + +var hbaAuthMethods = map[string]bool{ + "trust": true, + "reject": true, + "scram-sha-256": true, + "md5": true, + "password": true, + "gss": true, + "sspi": true, + "ident": true, + "peer": true, + "pam": true, + "ldap": true, + "radius": true, + "cert": true, + "oauth": true, +} + +var hbaSpecialAddresses = map[string]bool{ + "all": true, + "samehost": true, + "samenet": true, +} + +// tokenPattern splits on whitespace while keeping double-quoted strings intact. +var hbaTokenPattern = regexp.MustCompile(`(?:"+.*?"+|\S)+`) + +// hbaLabelPattern matches a valid DNS label sequence (hostname or domain suffix). +var hbaLabelPattern = regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?)*$`) + +// RuleError describes a validation error for a single pg_hba.conf rule. +type RuleError struct { + Index int + Message string +} + +// ValidateRules validates a slice of pg_hba.conf rule strings. +func ValidateRules(rules []string) []RuleError { + var errs []RuleError + for i, rule := range rules { + for _, msg := range validateRule(rule) { + errs = append(errs, RuleError{Index: i, Message: msg}) + } + } + return errs +} + +// validateRule validates a single pg_hba rule using positional parsing. +// pg_hba.conf has two formats: +// +// local DATABASE USER METHOD [OPTIONS] +// host* DATABASE USER ADDRESS METHOD [OPTIONS] +// host* DATABASE USER IP-ADDRESS NETMASK METHOD [OPTIONS] +// +// Validation order: connection type → minimum field count → auth method +// (at a fixed positional index) → address for host* types. The IP+netmask +// form is detected by checking whether tokens[4] parses as a valid IP. +func validateRule(rule string) []string { + trimmed := strings.TrimSpace(rule) + if trimmed == "" { + return nil + } + + tokens := tokenize(trimmed) + if len(tokens) == 0 { + return nil + } + + var errs []string + + connType := tokens[0] + if !hbaConnectionTypes[connType] { + return []string{fmt.Sprintf("unknown connection type %q", connType)} + } + + isLocal := connType == "local" + minFields := 5 // TYPE DATABASE USER ADDRESS METHOD + if isLocal { + minFields = 4 // local DATABASE USER METHOD + } + if len(tokens) < minFields { + return []string{fmt.Sprintf("too few fields: expected at least %d (%s DATABASE USER %sMETHOD), got %d", + minFields, connType, map[bool]string{true: "", false: "ADDRESS "}[isLocal], len(tokens))} + } + + methodIdx := 3 // local: tokens[3] + if !isLocal { + if len(tokens) > 5 && net.ParseIP(tokens[4]) != nil { + methodIdx = 5 + } else { + methodIdx = 4 + } + } + if methodIdx >= len(tokens) { + return []string{fmt.Sprintf("too few fields: missing auth method")} + } + method := tokens[methodIdx] + if !hbaAuthMethods[method] { + errs = append(errs, fmt.Sprintf("unknown auth method %q", method)) + } + + if !isLocal { + address := tokens[3] + if methodIdx == 5 { + if addrErr := validateIPNetmask(tokens[3], tokens[4]); addrErr != "" { + errs = append(errs, addrErr) + } + } else { + if addrErr := validateAddress(address); addrErr != "" { + errs = append(errs, addrErr) + } + } + } + + return errs +} + +// stripComment removes pg_hba.conf comments: a # outside double quotes starts +// a comment that runs to the end of the line. +func stripComment(line string) string { + inQuotes := false + for i, ch := range line { + switch ch { + case '"': + inQuotes = !inQuotes + case '#': + if !inQuotes { + return line[:i] + } + } + } + return line +} + +// tokenize splits a rule string on whitespace, keeping double-quoted strings intact. +// Comments (# to end of line, outside quotes) are stripped first. +func tokenize(line string) []string { + stripped := stripComment(line) + matches := hbaTokenPattern.FindAllString(stripped, -1) + var tokens []string + for _, m := range matches { + if s := strings.TrimSpace(m); s != "" { + tokens = append(tokens, s) + } + } + return tokens +} + +// validateAddress validates the address field for host* connection types. +func validateAddress(address string) string { + if hbaSpecialAddresses[address] { + return "" + } + + // Domain suffix match: .example.com + if strings.HasPrefix(address, ".") && len(address) > 1 { + suffix := address[1:] + if hbaLabelPattern.MatchString(suffix) { + return "" + } + return fmt.Sprintf("invalid domain suffix %q", address) + } + + // CIDR notation + if strings.Contains(address, "/") { + if _, _, err := net.ParseCIDR(address); err != nil { + return fmt.Sprintf("invalid CIDR address %q: %v", address, err) + } + return "" + } + + // IP address without CIDR (used with separate netmask field) + if ip := net.ParseIP(address); ip != nil { + return "" + } + + // Hostname + if hbaLabelPattern.MatchString(address) { + return "" + } + + return fmt.Sprintf("invalid address %q: expected CIDR, IP, hostname, or special keyword (all, samehost, samenet)", address) +} + +// validateIPNetmask validates the IP + netmask form (two separate fields). +func validateIPNetmask(ip, mask string) string { + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return fmt.Sprintf("invalid IP address %q in IP/netmask pair", ip) + } + + parsedMask := net.ParseIP(mask) + if parsedMask == nil { + return fmt.Sprintf("invalid netmask %q: not a valid IP address", mask) + } + + // Verify the mask is a valid contiguous subnet mask. + // Convert to 4 or 16 bytes depending on IPv4/IPv6. + var maskBytes net.IPMask + if v4 := parsedMask.To4(); v4 != nil { + maskBytes = net.IPMask(v4) + } else { + maskBytes = net.IPMask(parsedMask.To16()) + } + + // net.IPMask.Size() returns (ones, bits); ones == 0 && bits == 0 means invalid mask + ones, bits := maskBytes.Size() + if ones == 0 && bits == 0 { + return fmt.Sprintf("invalid netmask %q: not a contiguous subnet mask", mask) + } + + // IP and mask must be the same address family + ipIs4 := parsedIP.To4() != nil + maskIs4 := parsedMask.To4() != nil + if ipIs4 != maskIs4 { + return fmt.Sprintf("IP %q and netmask %q are not the same address family", ip, mask) + } + + return "" +} diff --git a/pkg/postgresql/cluster/core/hba_unit_test.go b/pkg/postgresql/cluster/core/hba_unit_test.go new file mode 100644 index 000000000..dffcb1cec --- /dev/null +++ b/pkg/postgresql/cluster/core/hba_unit_test.go @@ -0,0 +1,364 @@ +/* +Copyright 2026. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateRules(t *testing.T) { + t.Run("nil slice returns empty", func(t *testing.T) { + assert.Empty(t, ValidateRules(nil)) + }) + + t.Run("empty slice returns empty", func(t *testing.T) { + assert.Empty(t, ValidateRules([]string{})) + }) + + t.Run("all valid rules returns empty", func(t *testing.T) { + rules := []string{ + "local all all trust", + "host all all 0.0.0.0/0 scram-sha-256", + "hostssl all all 192.168.1.0/24 md5", + } + assert.Empty(t, ValidateRules(rules)) + }) + + t.Run("mixed valid and invalid returns correct indices", func(t *testing.T) { + rules := []string{ + "host all all 0.0.0.0/0 scram-sha-256", + "hostx all all 0.0.0.0/0 md5", + "host all all 0.0.0.0/0 md5", + } + errs := ValidateRules(rules) + require.Len(t, errs, 1) + assert.Equal(t, 1, errs[0].Index) + assert.Contains(t, errs[0].Message, "unknown connection type") + }) + + t.Run("multiple errors in different rules", func(t *testing.T) { + rules := []string{ + "hostx all all 0.0.0.0/0 md5", + "host all all 192.168.0.0/33 bogus", + } + errs := ValidateRules(rules) + require.Len(t, errs, 3) + assert.Equal(t, 0, errs[0].Index) + assert.Equal(t, 1, errs[1].Index) + assert.Equal(t, 1, errs[2].Index) + }) +} + +func TestValidateRule(t *testing.T) { + // === Valid rules === + + validRules := []struct { + name string + rule string + }{ + {"local basic", "local all all trust"}, + {"local with peer", "local postgres postgres peer"}, + {"host CIDR IPv4", "host all all 0.0.0.0/0 scram-sha-256"}, + {"hostssl CIDR", "hostssl all all 192.168.1.0/24 md5"}, + {"hostnossl reject", "hostnossl all all 0.0.0.0/0 reject"}, + {"hostgssenc", "hostgssenc all all 0.0.0.0/0 gss"}, + {"hostnogssenc", "hostnogssenc all all 0.0.0.0/0 scram-sha-256"}, + {"host replication", "host replication all 10.0.0.0/8 password"}, + {"host samehost", "host all all samehost trust"}, + {"host samenet", "host all all samenet trust"}, + {"host address all", "host all all all scram-sha-256"}, + {"host domain suffix", "host all all .example.com cert"}, + {"host IPv6", "host all all ::1/128 scram-sha-256"}, + {"host IPv6 all", "host all all ::0/0 md5"}, + {"host IP+netmask", "host all all 192.168.1.1 255.255.255.0 md5"}, + {"host IP+netmask /8", "host all all 10.0.0.0 255.0.0.0 md5"}, + {"inline comment", "host all all 0.0.0.0/0 md5 # office access"}, + {"inline comment with spaces", "host all all 0.0.0.0/0 md5 # allow all"}, + {"full-line comment", "# this is a comment"}, + {"comment-only with spaces", " # indented comment"}, + {"host auth options", "host all all 0.0.0.0/0 ldap ldapserver=ldap.example.com ldapport=389"}, + {"host quoted option", "host all all 0.0.0.0/0 ident map=omicron"}, + {"host quoted value", `host all all 0.0.0.0/0 ldap ldapprefix="cn="`}, + {"quoted db with equals", `host "db=name" all 0.0.0.0/0 md5`}, + {"comma-separated db", "host db1,db2 all 0.0.0.0/0 md5"}, + {"comma-separated user", "host all user1,user2 0.0.0.0/0 md5"}, + {"host hostname", "host all all myhost.example.com md5"}, + {"host with sspi", "host all all 0.0.0.0/0 sspi"}, + {"host with ident", "host all all 0.0.0.0/0 ident"}, + {"host with pam", "host all all 0.0.0.0/0 pam"}, + {"host with radius", "host all all 0.0.0.0/0 radius"}, + {"host with oauth (PG18)", "host all all 0.0.0.0/0 oauth"}, + {"empty string", ""}, + {"whitespace only", " "}, + } + + for _, tc := range validRules { + t.Run("valid/"+tc.name, func(t *testing.T) { + errs := validateRule(tc.rule) + assert.Empty(t, errs, "expected no errors for rule %q, got: %v", tc.rule, errs) + }) + } + + // === Layer 0: connection type errors === + + t.Run("layer0/unknown connection type", func(t *testing.T) { + errs := validateRule("hostx all all 0.0.0.0/0 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], `unknown connection type "hostx"`) + }) + + t.Run("layer0/uppercase connection type", func(t *testing.T) { + errs := validateRule("HOST all all 0.0.0.0/0 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], `unknown connection type "HOST"`) + }) + + // === Layer 1: field count errors === + + t.Run("layer1/host missing method", func(t *testing.T) { + errs := validateRule("host all all 0.0.0.0/0") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "too few fields") + }) + + t.Run("layer1/host only three fields", func(t *testing.T) { + errs := validateRule("host all all") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "too few fields") + }) + + t.Run("layer1/local missing user and method", func(t *testing.T) { + errs := validateRule("local all") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "too few fields") + }) + + t.Run("layer1/local missing method", func(t *testing.T) { + errs := validateRule("local all all") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "too few fields") + }) + + // === Layer 2: auth method errors === + + t.Run("layer2/unknown auth method", func(t *testing.T) { + errs := validateRule("host all all 0.0.0.0/0 bogus") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], `unknown auth method "bogus"`) + }) + + t.Run("layer2/typo scram-sha256", func(t *testing.T) { + errs := validateRule("host all all 0.0.0.0/0 scram-sha256") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], `unknown auth method "scram-sha256"`) + }) + + t.Run("layer2/local unknown method", func(t *testing.T) { + errs := validateRule("local all all unknown") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], `unknown auth method "unknown"`) + }) + + // === Layer 3: address errors === + + t.Run("layer3/invalid CIDR mask too large", func(t *testing.T) { + errs := validateRule("host all all 192.168.0.0/33 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "invalid CIDR") + }) + + t.Run("layer3/invalid IP in CIDR", func(t *testing.T) { + errs := validateRule("host all all 256.1.1.1/24 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "invalid CIDR") + }) + + t.Run("layer3/garbage address", func(t *testing.T) { + errs := validateRule("host all all not@valid md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "invalid address") + }) + + // === Layer 3: netmask errors === + + t.Run("layer3/non-contiguous netmask", func(t *testing.T) { + errs := validateRule("host all all 10.0.0.1 255.0.255.0 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "not a contiguous subnet mask") + }) + + t.Run("layer3/invalid IP in netmask pair", func(t *testing.T) { + errs := validateRule("host all all 999.0.0.1 255.255.255.0 md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "invalid IP address") + }) + + t.Run("layer2/garbage where netmask expected", func(t *testing.T) { + errs := validateRule("host all all 10.0.0.1 notamask md5") + require.Len(t, errs, 1) + assert.Contains(t, errs[0], "unknown auth method") + }) + + // === Multiple errors in one rule === + + t.Run("multiple/bad method and bad address", func(t *testing.T) { + errs := validateRule("host all all 192.168.0.0/33 bogus") + assert.Len(t, errs, 2) + }) +} + +func TestTokenize(t *testing.T) { + t.Run("simple fields", func(t *testing.T) { + tokens := tokenize("host all all 0.0.0.0/0 md5") + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "md5"}, tokens) + }) + + t.Run("multiple spaces", func(t *testing.T) { + tokens := tokenize("host all all 0.0.0.0/0 md5") + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "md5"}, tokens) + }) + + t.Run("quoted string preserved", func(t *testing.T) { + tokens := tokenize(`host all all 0.0.0.0/0 ldap ldapprefix="cn="`) + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "ldap", `ldapprefix="cn="`}, tokens) + }) + + t.Run("auth option with equals", func(t *testing.T) { + tokens := tokenize("host all all 0.0.0.0/0 ident map=omicron") + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "ident", "map=omicron"}, tokens) + }) + + t.Run("empty string", func(t *testing.T) { + tokens := tokenize("") + assert.Empty(t, tokens) + }) + + t.Run("inline comment stripped", func(t *testing.T) { + tokens := tokenize("host all all 0.0.0.0/0 md5 # office access") + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "md5"}, tokens) + }) + + t.Run("full-line comment", func(t *testing.T) { + tokens := tokenize("# this is a comment") + assert.Empty(t, tokens) + }) + + t.Run("hash inside quotes not treated as comment", func(t *testing.T) { + tokens := tokenize(`host all all 0.0.0.0/0 ldap ldapprefix="cn=#test"`) + assert.Equal(t, []string{"host", "all", "all", "0.0.0.0/0", "ldap", `ldapprefix="cn=#test"`}, tokens) + }) +} + +func TestStripComment(t *testing.T) { + t.Run("no comment", func(t *testing.T) { + assert.Equal(t, "host all all 0.0.0.0/0 md5", stripComment("host all all 0.0.0.0/0 md5")) + }) + + t.Run("inline comment", func(t *testing.T) { + assert.Equal(t, "host all all 0.0.0.0/0 md5 ", stripComment("host all all 0.0.0.0/0 md5 # comment")) + }) + + t.Run("full-line comment", func(t *testing.T) { + assert.Equal(t, "", stripComment("# full line comment")) + }) + + t.Run("hash inside quotes preserved", func(t *testing.T) { + assert.Equal(t, `host all all 0.0.0.0/0 ldap ldapprefix="cn=#x"`, stripComment(`host all all 0.0.0.0/0 ldap ldapprefix="cn=#x"`)) + }) + + t.Run("hash after closing quote", func(t *testing.T) { + assert.Equal(t, `host all all 0.0.0.0/0 ldap ldapprefix="cn" `, stripComment(`host all all 0.0.0.0/0 ldap ldapprefix="cn" # comment`)) + }) +} + +func TestValidateIPNetmask(t *testing.T) { + t.Run("valid IPv4", func(t *testing.T) { + assert.Empty(t, validateIPNetmask("192.168.1.0", "255.255.255.0")) + }) + + t.Run("valid /8", func(t *testing.T) { + assert.Empty(t, validateIPNetmask("10.0.0.0", "255.0.0.0")) + }) + + t.Run("invalid IP", func(t *testing.T) { + result := validateIPNetmask("999.0.0.1", "255.255.255.0") + assert.Contains(t, result, "invalid IP address") + }) + + t.Run("invalid mask not an IP", func(t *testing.T) { + result := validateIPNetmask("10.0.0.1", "notamask") + assert.Contains(t, result, "invalid netmask") + }) + + t.Run("non-contiguous mask", func(t *testing.T) { + result := validateIPNetmask("10.0.0.1", "255.0.255.0") + assert.Contains(t, result, "not a contiguous subnet mask") + }) +} + +func TestValidateAddress(t *testing.T) { + validAddresses := []string{ + "0.0.0.0/0", + "192.168.1.0/24", + "10.0.0.0/8", + "::1/128", + "::0/0", + "all", + "samehost", + "samenet", + ".example.com", + ".sub.domain.com", + "192.168.1.1", + "myhost.example.com", + "my-host", + "localhost", + } + + for _, addr := range validAddresses { + t.Run("valid/"+addr, func(t *testing.T) { + assert.Empty(t, validateAddress(addr)) + }) + } + + invalidAddresses := []struct { + name string + address string + errMsg string + }{ + {"CIDR mask too large", "192.168.0.0/33", "invalid CIDR"}, + {"invalid IP in CIDR", "256.1.1.1/24", "invalid CIDR"}, + {"bad CIDR format", "999.999.999.999/32", "invalid CIDR"}, + {"special chars", "host@name", "invalid address"}, + {"spaces in addr", "my host", "invalid address"}, + {"double dot hostname", "myhost..example.com", "invalid address"}, + {"leading dash hostname", "-myhost", "invalid address"}, + {"trailing dash hostname", "myhost-", "invalid address"}, + {"double dot domain suffix", ".foo..bar", "invalid domain suffix"}, + {"dash-prefixed domain suffix", ".-bad", "invalid domain suffix"}, + {"trailing dash domain suffix", ".bad-", "invalid domain suffix"}, + } + + for _, tc := range invalidAddresses { + t.Run("invalid/"+tc.name, func(t *testing.T) { + result := validateAddress(tc.address) + assert.Contains(t, result, tc.errMsg) + }) + } +} diff --git a/pkg/splunk/enterprise/validation/registry.go b/pkg/splunk/enterprise/validation/registry.go index 98b386f18..5eab98402 100644 --- a/pkg/splunk/enterprise/validation/registry.go +++ b/pkg/splunk/enterprise/validation/registry.go @@ -20,6 +20,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" enterpriseApi "github.com/splunk/splunk-operator/api/v4" + pgwebhook "github.com/splunk/splunk-operator/pkg/postgresql/cluster/adapter/webhook" ) // GVR constants for all Splunk Enterprise CRDs @@ -71,6 +72,18 @@ var ( Version: "v4", Resource: "monitoringconsoles", } + + PostgresClusterGVR = schema.GroupVersionResource{ + Group: "enterprise.splunk.com", + Version: "v4", + Resource: "postgresclusters", + } + + PostgresClusterClassGVR = schema.GroupVersionResource{ + Group: "enterprise.splunk.com", + Version: "v4", + Resource: "postgresclusterclasses", + } ) // DefaultValidators is the registry of validators for all Splunk Enterprise CRDs @@ -180,4 +193,26 @@ var DefaultValidators = map[schema.GroupVersionResource]Validator{ Kind: "MonitoringConsole", }, }, + + PostgresClusterGVR: &GenericValidator[*enterpriseApi.PostgresCluster]{ + ValidateCreateFunc: pgwebhook.ValidatePostgresClusterCreate, + ValidateUpdateFunc: pgwebhook.ValidatePostgresClusterUpdate, + WarningsOnCreateFunc: pgwebhook.GetPostgresClusterWarningsOnCreate, + WarningsOnUpdateFunc: pgwebhook.GetPostgresClusterWarningsOnUpdate, + GroupKind: schema.GroupKind{ + Group: "enterprise.splunk.com", + Kind: "PostgresCluster", + }, + }, + + PostgresClusterClassGVR: &GenericValidator[*enterpriseApi.PostgresClusterClass]{ + ValidateCreateFunc: pgwebhook.ValidatePostgresClusterClassCreate, + ValidateUpdateFunc: pgwebhook.ValidatePostgresClusterClassUpdate, + WarningsOnCreateFunc: pgwebhook.GetPostgresClusterClassWarningsOnCreate, + WarningsOnUpdateFunc: pgwebhook.GetPostgresClusterClassWarningsOnUpdate, + GroupKind: schema.GroupKind{ + Group: "enterprise.splunk.com", + Kind: "PostgresClusterClass", + }, + }, } diff --git a/pkg/splunk/enterprise/validation/server.go b/pkg/splunk/enterprise/validation/server.go index c94e03f45..882f89878 100644 --- a/pkg/splunk/enterprise/validation/server.go +++ b/pkg/splunk/enterprise/validation/server.go @@ -80,7 +80,7 @@ func (s *WebhookServer) Start(ctx context.Context) error { mux := http.NewServeMux() // Register validation endpoint - mux.HandleFunc("/validate", s.handleValidate) + mux.HandleFunc("/validate", s.HandleValidate) // Register health check endpoint mux.HandleFunc("/readyz", s.handleReadyz) @@ -140,8 +140,8 @@ func (s *WebhookServer) Start(ctx context.Context) error { } } -// handleValidate handles validation requests -func (s *WebhookServer) handleValidate(w http.ResponseWriter, r *http.Request) { +// HandleValidate handles validation requests +func (s *WebhookServer) HandleValidate(w http.ResponseWriter, r *http.Request) { reqLog := log.FromContext(r.Context()).WithName("webhook-server") reqLog.V(1).Info("Received validation request", "method", r.Method, "path", r.URL.Path) diff --git a/pkg/splunk/enterprise/validation/server_test.go b/pkg/splunk/enterprise/validation/server_test.go index 0b2543014..020825041 100644 --- a/pkg/splunk/enterprise/validation/server_test.go +++ b/pkg/splunk/enterprise/validation/server_test.go @@ -253,7 +253,7 @@ func TestHandleValidate(t *testing.T) { req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - server.handleValidate(rr, req) + server.HandleValidate(rr, req) if rr.Code != tt.wantStatusCode { t.Errorf("expected status code %d, got %d", tt.wantStatusCode, rr.Code) @@ -382,7 +382,7 @@ func TestHandleValidateWithWarnings(t *testing.T) { req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - server.handleValidate(rr, req) + server.HandleValidate(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status code %d, got %d", http.StatusOK, rr.Code)