diff --git a/pkg/internal/utils/basic_test.go b/pkg/internal/utils/basic_test.go new file mode 100644 index 0000000..1dc8f5e --- /dev/null +++ b/pkg/internal/utils/basic_test.go @@ -0,0 +1,183 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "net/url" + "testing" +) + +func TestParseBasicAuth(t *testing.T) { + tests := []struct { + name string + auth string + wantUser string + wantPass string + wantNil bool + }{ + { + name: "valid auth with password", + auth: "Basic dXNlcjpwYXNzd29yZA==", // user:password + wantUser: "user", + wantPass: "password", + wantNil: false, + }, + { + name: "valid auth without password", + auth: "Basic dXNlcg==", // user + wantUser: "user", + wantPass: "", + wantNil: false, + }, + { + name: "invalid prefix", + auth: "Bearer token123", + wantNil: true, + }, + { + name: "missing prefix", + auth: "dXNlcjpwYXNzd29yZA==", + wantNil: true, + }, + { + name: "invalid base64", + auth: "Basic invalid!!!", + wantNil: true, + }, + { + name: "empty string", + auth: "", + wantNil: true, + }, + { + name: "case insensitive prefix", + auth: "basic dXNlcjpwYXNzd29yZA==", // lowercase basic + wantUser: "user", + wantPass: "password", + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseBasicAuth(tt.auth) + if tt.wantNil { + if got != nil { + t.Errorf("ParseBasicAuth() = %v, want nil", got) + } + return + } + if got == nil { + t.Errorf("ParseBasicAuth() = nil, want non-nil") + return + } + if got.Username() != tt.wantUser { + t.Errorf("ParseBasicAuth() username = %v, want %v", got.Username(), tt.wantUser) + } + gotPass, _ := got.Password() + if gotPass != tt.wantPass { + t.Errorf("ParseBasicAuth() password = %v, want %v", gotPass, tt.wantPass) + } + }) + } +} + +func TestFormathBasicAuth(t *testing.T) { + tests := []struct { + name string + ui *url.Userinfo + want string + }{ + { + name: "with password", + ui: url.UserPassword("user", "password"), + want: "Basic dXNlcjpwYXNzd29yZA==", + }, + { + name: "without password", + ui: url.User("user"), + want: "Basic dXNlcg==", + }, + { + name: "nil userinfo", + ui: nil, + want: "", + }, + { + name: "special characters in username", + ui: url.UserPassword("admin@domain.com", "pass123"), + want: "Basic YWRtaW4lNDBkb21haW4uY29tOnBhc3MxMjM=", // @ is URL-encoded to %40 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FormathBasicAuth(tt.ui); got != tt.want { + t.Errorf("FormathBasicAuth() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseAndFormatBasicAuth_RoundTrip(t *testing.T) { + tests := []struct { + name string + username string + password string + }{ + { + name: "simple credentials", + username: "user", + password: "password", + }, + { + name: "username with numbers", + username: "admin123", + password: "secure123", + }, + { + name: "alphanumeric only", + username: "testuser", + password: "Passw0rd123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Format credentials + ui := url.UserPassword(tt.username, tt.password) + formatted := FormathBasicAuth(ui) + + // Parse them back + parsed := ParseBasicAuth(formatted) + if parsed == nil { + t.Fatal("ParseBasicAuth() returned nil") + } + + // For alphanumeric credentials without special chars, + // the round trip should preserve the values exactly + if parsed.Username() != tt.username { + t.Errorf("Round trip username = %v, want %v", parsed.Username(), tt.username) + } + + parsedPass, _ := parsed.Password() + if parsedPass != tt.password { + t.Errorf("Round trip password = %v, want %v", parsedPass, tt.password) + } + }) + } +} diff --git a/pkg/internal/utils/identity_test.go b/pkg/internal/utils/identity_test.go new file mode 100644 index 0000000..551dc3c --- /dev/null +++ b/pkg/internal/utils/identity_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "strings" + "testing" +) + +func TestIdentity(t *testing.T) { + // Test that Identity() returns a non-empty string + id, err := Identity() + if err != nil { + t.Fatalf("Identity() error = %v, want nil", err) + } + if id == "" { + t.Error("Identity() returned empty string") + } + + // Test format: should be "hexstring-timestamp" + parts := strings.Split(id, "-") + if len(parts) != 2 { + t.Errorf("Identity() format = %v, want 'hex-timestamp' with 2 parts", id) + } + + // Test that hex part is 16 characters + hexPart := parts[0] + if len(hexPart) != 16 { + t.Errorf("Identity() hex part length = %d, want 16", len(hexPart)) + } + + // Test that hex part is valid hex + for _, c := range hexPart { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("Identity() hex part contains invalid character: %c", c) + } + } + + // Test that timestamp part is numeric + timestampPart := parts[1] + for _, c := range timestampPart { + if c < '0' || c > '9' { + t.Errorf("Identity() timestamp part contains non-digit: %c", c) + } + } +} + +func TestIdentity_Multiple(t *testing.T) { + // Test that calling Identity() multiple times returns different values + // (due to different timestamps) + id1, err := Identity() + if err != nil { + t.Fatalf("Identity() first call error = %v, want nil", err) + } + + id2, err := Identity() + if err != nil { + t.Fatalf("Identity() second call error = %v, want nil", err) + } + + // The IDs should be different because timestamps are different + // Note: In very rare cases on very fast systems, timestamps might be the same + // but that's acceptable for this test + if id1 == id2 { + t.Logf("Identity() returned same value twice: %v (acceptable if timestamps are equal)", id1) + } + + // The hex parts should be the same (same hostname) + parts1 := strings.Split(id1, "-") + parts2 := strings.Split(id2, "-") + if parts1[0] != parts2[0] { + t.Errorf("Identity() hex parts differ: %v vs %v (should be same for same hostname)", parts1[0], parts2[0]) + } +} + +func TestIdentity_Format(t *testing.T) { + id, err := Identity() + if err != nil { + t.Fatalf("Identity() error = %v, want nil", err) + } + + // Verify the identity matches expected pattern + if !strings.Contains(id, "-") { + t.Errorf("Identity() = %v, want to contain '-'", id) + } + + // Verify overall length is reasonable + // 16 hex chars + 1 dash + at least 10 digit timestamp = at least 27 chars + if len(id) < 27 { + t.Errorf("Identity() length = %d, want at least 27", len(id)) + } +} diff --git a/pkg/internal/utils/network_test.go b/pkg/internal/utils/network_test.go new file mode 100644 index 0000000..078ed88 --- /dev/null +++ b/pkg/internal/utils/network_test.go @@ -0,0 +1,291 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "errors" + "io" + "net" + "net/http" + "testing" + "time" +) + +func TestIsNetWorkError(t *testing.T) { + tests := []struct { + name string + err error + wantIsNet bool + }{ + { + name: "nil error", + err: nil, + wantIsNet: false, + }, + { + name: "io.EOF", + err: io.EOF, + wantIsNet: true, + }, + { + name: "io.ErrUnexpectedEOF", + err: io.ErrUnexpectedEOF, + wantIsNet: true, + }, + { + name: "network timeout error", + err: &net.DNSError{IsTimeout: true}, + wantIsNet: true, + }, + { + name: "network temporary error", + err: &net.DNSError{IsTemporary: true}, + wantIsNet: true, + }, + { + name: "generic error", + err: errors.New("some error"), + wantIsNet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsNet, gotErr := IsNetWorkError(tt.err) + if gotIsNet != tt.wantIsNet { + t.Errorf("IsNetWorkError() isNet = %v, want %v", gotIsNet, tt.wantIsNet) + } + if tt.err == nil { + if gotErr != nil { + t.Errorf("IsNetWorkError() error = %v, want nil", gotErr) + } + } else { + if gotErr == nil { + t.Errorf("IsNetWorkError() error = nil, want non-nil") + } + } + }) + } +} + +func TestIsHTTPResponseError(t *testing.T) { + tests := []struct { + name string + resp *http.Response + err error + wantIsErr bool + wantErrNil bool + }{ + { + name: "successful response", + resp: &http.Response{ + StatusCode: http.StatusOK, + }, + err: nil, + wantIsErr: false, + wantErrNil: true, + }, + { + name: "server error", + resp: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + err: nil, + wantIsErr: true, + wantErrNil: false, + }, + { + name: "bad gateway", + resp: &http.Response{ + StatusCode: http.StatusBadGateway, + }, + err: nil, + wantIsErr: true, + wantErrNil: false, + }, + { + name: "service unavailable", + resp: &http.Response{ + StatusCode: http.StatusServiceUnavailable, + }, + err: nil, + wantIsErr: true, + wantErrNil: false, + }, + { + name: "too many requests", + resp: &http.Response{ + StatusCode: http.StatusTooManyRequests, + }, + err: nil, + wantIsErr: true, + wantErrNil: false, + }, + { + name: "client error - not retryable", + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + err: nil, + wantIsErr: false, + wantErrNil: true, + }, + { + name: "not found - not retryable", + resp: &http.Response{ + StatusCode: http.StatusNotFound, + }, + err: nil, + wantIsErr: false, + wantErrNil: true, + }, + { + name: "network error with nil response", + resp: nil, + err: io.EOF, + wantIsErr: true, + wantErrNil: false, + }, + { + name: "nil response and nil error", + resp: nil, + err: nil, + wantIsErr: false, + wantErrNil: false, + }, + { + name: "generic error", + resp: nil, + err: errors.New("some error"), + wantIsErr: false, + wantErrNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsErr, gotErr := IsHTTPResponseError(tt.resp, tt.err) + if gotIsErr != tt.wantIsErr { + t.Errorf("IsHTTPResponseError() isErr = %v, want %v", gotIsErr, tt.wantIsErr) + } + if tt.wantErrNil { + if gotErr != nil { + t.Errorf("IsHTTPResponseError() error = %v, want nil", gotErr) + } + } else { + if gotErr == nil { + t.Errorf("IsHTTPResponseError() error = nil, want non-nil") + } + } + }) + } +} + +// mockNetError implements net.Error interface for testing +type mockNetError struct { + timeout bool + temporary bool + msg string +} + +func (m *mockNetError) Error() string { + return m.msg +} + +func (m *mockNetError) Timeout() bool { + return m.timeout +} + +func (m *mockNetError) Temporary() bool { + return m.temporary +} + +func TestIsNetWorkError_WithMockNetError(t *testing.T) { + tests := []struct { + name string + err error + wantIsNet bool + }{ + { + name: "timeout error", + err: &mockNetError{ + timeout: true, + msg: "timeout", + }, + wantIsNet: true, + }, + { + name: "temporary error", + err: &mockNetError{ + temporary: true, + msg: "temporary", + }, + wantIsNet: true, + }, + { + name: "wrapped EOF", + err: &wrappedError{ + msg: "connection failed", + err: io.EOF, + }, + wantIsNet: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsNet, gotErr := IsNetWorkError(tt.err) + if gotIsNet != tt.wantIsNet { + t.Errorf("IsNetWorkError() isNet = %v, want %v", gotIsNet, tt.wantIsNet) + } + if gotErr == nil { + t.Errorf("IsNetWorkError() error = nil, want non-nil") + } + }) + } +} + +// wrappedError wraps an error for testing +type wrappedError struct { + msg string + err error +} + +func (w *wrappedError) Error() string { + return w.msg +} + +func (w *wrappedError) Unwrap() error { + return w.err +} + +// Test real network timeout scenario +func TestIsNetWorkError_RealNetworkTimeout(t *testing.T) { + // Create a real network timeout error + _, err := net.DialTimeout("tcp", "192.0.2.1:80", 1*time.Nanosecond) + if err == nil { + t.Skip("Expected timeout error, got nil") + } + + isNet, gotErr := IsNetWorkError(err) + if !isNet { + t.Errorf("IsNetWorkError() isNet = false, want true for real network timeout") + } + if gotErr == nil { + t.Errorf("IsNetWorkError() error = nil, want non-nil") + } +}