Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
26 changes: 16 additions & 10 deletions cmd/auth-provider-gcp/app/getcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ const (

// CredentialOptions contains a representation of the options passed to the credential provider.
type CredentialOptions struct {
AuthFlow string
AuthFlow string
WIFProjectNumber string
WIFPoolID string
WIFProviderID string
}

// AuthFlowFlagError represents an error that occurred during flag validation.
Expand Down Expand Up @@ -82,7 +85,7 @@ func NewGetCredentialsCommand() (*cobra.Command, error) {
Use: "get-credentials",
Short: "Get authentication credentials",
RunE: func(cmd *cobra.Command, args []string) error {
return getCredentials(options.AuthFlow)
return getCredentials(options)
},
}
defineFlags(cmd, &options)
Expand All @@ -92,23 +95,23 @@ func NewGetCredentialsCommand() (*cobra.Command, error) {
return cmd, nil
}

func providerFromFlow(flow string) (credentialconfig.DockerConfigProvider, error) {
func makeProvider(options CredentialOptions) (credentialconfig.DockerConfigProvider, error) {
transport := utilnet.SetTransportDefaults(&http.Transport{})
switch flow {
switch options.AuthFlow {
case gcrAuthFlow:
return provider.MakeRegistryProvider(transport), nil
return provider.MakeRegistryProvider(transport, options.WIFProjectNumber, options.WIFPoolID, options.WIFProviderID)
case dockerConfigAuthFlow:
return provider.MakeDockerConfigProvider(transport), nil
case dockerConfigURLAuthFlow:
return provider.MakeDockerConfigURLProvider(transport), nil
default:
return nil, &AuthFlowTypeError{requestedFlow: flow}
return nil, &AuthFlowTypeError{requestedFlow: options.AuthFlow}
}
}

func getCredentials(authFlow string) error {
klog.V(2).Infof("get-credentials (authFlow %s)", authFlow)
authProvider, err := providerFromFlow(authFlow)
func getCredentials(options CredentialOptions) error {
klog.V(2).Infof("get-credentials (authFlow %s)", options.AuthFlow)
authProvider, err := makeProvider(options)
if err != nil {
return err
}
Expand All @@ -121,7 +124,7 @@ func getCredentials(authFlow string) error {
if err != nil {
return fmt.Errorf("error unmarshaling auth credential request: %w", err)
}
authCredentials, err := provider.GetResponse(authRequest.Image, authProvider)
authCredentials, err := provider.GetResponse(authRequest, authProvider)
if err != nil {
return fmt.Errorf("error getting authentication response from provider: %w", err)
}
Expand All @@ -137,6 +140,9 @@ func getCredentials(authFlow string) error {

func defineFlags(credCmd *cobra.Command, options *CredentialOptions) {
credCmd.Flags().StringVarP(&options.AuthFlow, "authFlow", "a", gcrAuthFlow, fmt.Sprintf("authentication flow (valid values are %q, %q, and %q)", gcrAuthFlow, dockerConfigAuthFlow, dockerConfigURLAuthFlow))
credCmd.Flags().StringVarP(&options.WIFProjectNumber, "gcpWIFProjectNumber", "", "", fmt.Sprintf("Number of GCP project used for Workload Identity Federation (required when using Service Account Token Integration for image pulls)."))
credCmd.Flags().StringVarP(&options.WIFPoolID, "gcpWIFPoolID", "", "", fmt.Sprintf("ID of the Workload Identity Pool used for Workload Identity Federation (required when using Service Account Token Integration for image pulls)."))
credCmd.Flags().StringVarP(&options.WIFProviderID, "gcpWIFProviderID", "", "", fmt.Sprintf("ID of the Workload Identity Provider used for Workload Identity Federation (required when using Service Account Token Integration for image pulls)."))
}

func validateFlags(options *CredentialOptions) error {
Expand Down
2 changes: 2 additions & 0 deletions cmd/auth-provider-gcp/app/getcredentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestValidateAuthFlow(t *testing.T) {
{Name: "validate gcr auth flow", Flow: gcrAuthFlow},
{Name: "validate docker-cfg auth flow option", Flow: dockerConfigAuthFlow},
{Name: "validate docker-cfg-url auth flow option", Flow: dockerConfigURLAuthFlow},
{Name: "validate k8s-sa-wif auth flow option", Flow: k8sSAWIFAuthFlow},
{Name: "bad auth flow option", Flow: "bad-flow", Error: &AuthFlowFlagError{flagValue: "bad-flow"}},
{Name: "empty auth flow option", Flow: "", Error: &AuthFlowFlagError{flagValue: ""}},
{Name: "case-sensitive auth flow", Flow: "Gcrauthflow", Error: &AuthFlowFlagError{flagValue: "Gcrauthflow"}},
Expand Down Expand Up @@ -67,6 +68,7 @@ func TestProviderFromFlow(t *testing.T) {
{Name: "gcr auth provider selection", Flow: gcrAuthFlow, Type: "ContainerRegistryProvider"},
{Name: "docker-cfg auth provider selection", Flow: dockerConfigAuthFlow, Type: "DockerConfigKeyProvider"},
{Name: "docker-cfg-url auth provider selection", Flow: dockerConfigURLAuthFlow, Type: "DockerConfigURLKeyProvider"},
{Name: "k8s-sa-wif auth provider selection", Flow: k8sSAWIFAuthFlow, Type: "K8sSAWIFProvider"},
{Name: "non-existent auth provider request", Flow: "bad-flow", Type: "", Error: &AuthFlowTypeError{requestedFlow: "bad-flow"}},
{Name: "empty auth provider request", Flow: "", Type: "", Error: &AuthFlowTypeError{requestedFlow: ""}},
}
Expand Down
23 changes: 19 additions & 4 deletions cmd/auth-provider-gcp/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ limitations under the License.
package provider

import (
"context"
"fmt"
"net/http"
"os"
"time"

"google.golang.org/api/option"
"google.golang.org/api/sts/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/cloud-provider-gcp/pkg/credentialconfig"
"k8s.io/cloud-provider-gcp/pkg/gcpcredential"
Expand All @@ -40,13 +43,25 @@ const (
)

// MakeRegistryProvider returns a ContainerRegistryProvider with the given transport.
func MakeRegistryProvider(transport *http.Transport) *gcpcredential.ContainerRegistryProvider {
func MakeRegistryProvider(transport *http.Transport, projectNumber, poolID, providerID string) (*gcpcredential.ContainerRegistryProvider, error) {
httpClient := makeHTTPClient(transport)
provider := &gcpcredential.ContainerRegistryProvider{
MetadataProvider: gcpcredential.MetadataProvider{Client: httpClient},
UseRegistryFromImage: true,
}
return provider
if projectNumber != "" && poolID != "" && providerID != "" {
stsService, err := sts.NewService(context.Background(), option.WithHTTPClient(httpClient))
if err != nil {
return nil, fmt.Errorf("failed to create sts service: %w", err)
}
provider.WIFProvider = gcpcredential.WIFProvider{
StsService: stsService,
ProjectNumber: projectNumber,
PoolId: poolID,
ProviderId: providerID,
}
}
return provider, nil
}

// MakeDockerConfigProvider returns a DockerConfigKeyProvider with the given transport.
Expand Down Expand Up @@ -109,8 +124,8 @@ func getCacheKeyType() (credentialproviderapi.PluginCacheKeyType, error) {
}

// GetResponse queries the given provider for credentials.
func GetResponse(image string, provider credentialconfig.DockerConfigProvider) (*credentialproviderapi.CredentialProviderResponse, error) {
cfg := provider.Provide(image)
func GetResponse(authRequest credentialproviderapi.CredentialProviderRequest, provider credentialconfig.DockerConfigProvider) (*credentialproviderapi.CredentialProviderResponse, error) {
cfg := provider.Provide(authRequest)
response := &credentialproviderapi.CredentialProviderResponse{Auth: make(map[string]credentialproviderapi.AuthConfig)}
for url, dockerConfig := range cfg {
response.Auth[url] = credentialproviderapi.AuthConfig{Username: dockerConfig.Username, Password: dockerConfig.Password}
Expand Down
117 changes: 114 additions & 3 deletions cmd/auth-provider-gcp/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ limitations under the License.
package provider

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -82,7 +85,7 @@ func TestContainerRegistry(t *testing.T) {
},
})
provider := MakeRegistryProvider(transport)
response, err := GetResponse(dummyImage, provider)
response, err := GetResponse(credentialproviderapi.CredentialProviderRequest{Image: dummyImage}, provider)
if err != nil {
t.Fatalf("Unexpected error while getting response: %s", err.Error())
}
Expand Down Expand Up @@ -144,7 +147,7 @@ func TestConfigProvider(t *testing.T) {
},
})
provider := MakeDockerConfigProvider(transport)
response, err := GetResponse(dummyImage, provider)
response, err := GetResponse(credentialproviderapi.CredentialProviderRequest{Image: dummyImage}, provider)
if err != nil {
t.Fatalf("Unexpected error while getting response: %s", err.Error())
}
Expand Down Expand Up @@ -201,7 +204,7 @@ func TestConfigURLProvider(t *testing.T) {
})

provider := MakeDockerConfigURLProvider(transport)
response, err := GetResponse(dummyImage, provider)
response, err := GetResponse(credentialproviderapi.CredentialProviderRequest{Image: dummyImage}, provider)
if err != nil {
t.Fatalf("Unexpected error while getting response: %s", err.Error())
}
Expand All @@ -217,3 +220,111 @@ func TestConfigURLProvider(t *testing.T) {
}
}
}

func TestK8sSAWIFProvider(t *testing.T) {
registryURL := strings.Split(dummyImage, "/")[0]
gcpRegistryURL := "container.cloud.google.com"

projectNum := "123456789"
poolId := "test-pool"
providerId := "test-provider"

os.Setenv("GCP_WIF_PROJECT_NUMBER", projectNum)
os.Setenv("GCP_WIF_POOL_ID", poolId)
os.Setenv("GCP_WIF_PROVIDER_ID", providerId)
defer os.Unsetenv("GCP_WIF_PROJECT_NUMBER")
defer os.Unsetenv("GCP_WIF_POOL_ID")
defer os.Unsetenv("GCP_WIF_PROVIDER_ID")

server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/v1/token") {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
tokenResponse := map[string]interface{}{
"access_token": dummyToken,
"expires_in": 3600,
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
}
resp, _ := json.Marshal(tokenResponse)
if _, err := w.Write(resp); err != nil {
t.Fatalf("write token response: %v", err)
}
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()

serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}

transport := server.Client().Transport.(*http.Transport).Clone()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if strings.HasPrefix(addr, "sts.googleapis.com:") {
var d net.Dialer
return d.DialContext(ctx, network, serverURL.Host)
}
var d net.Dialer
return d.DialContext(ctx, network, addr)
}
transport.TLSClientConfig.ServerName = "127.0.0.1"

provider, err := MakeK8sSAWIFProvider(transport)
if err != nil {
t.Fatalf("Unexpected error while creating provider: %v", err)
}

if provider == nil {
t.Fatalf("Expected K8sSAWIFProvider but got nil")
}

if provider.WIFConfig.ProjectNumber != projectNum {
t.Errorf("Expected project number %s, got %s", projectNum, provider.WIFConfig.ProjectNumber)
}
if provider.WIFConfig.PoolId != poolId {
t.Errorf("Expected pool ID %s, got %s", poolId, provider.WIFConfig.PoolId)
}
if provider.WIFConfig.ProviderId != providerId {
t.Errorf("Expected provider ID %s, got %s", providerId, provider.WIFConfig.ProviderId)
}
if !provider.UseRegistryFromImage {
t.Errorf("Expected UseRegistryFromImage to be true")
}
if provider.StsService == nil {
t.Errorf("Expected StsService to be configured")
}

response, err := GetResponse(credentialproviderapi.CredentialProviderRequest{Image: dummyImage}, provider)
if err != nil {
t.Fatalf("Unexpected error while getting response: %v", err)
}

if !hasURL(registryURL, response) || !hasURL(gcpRegistryURL, response) {
if !hasURL(registryURL, response) {
t.Errorf("URL %s expected in response, not found (response: %s)", registryURL, response.Auth)
}
if !hasURL(gcpRegistryURL, response) {
t.Errorf("URL %s expected in response, not found (response: %s)", gcpRegistryURL, response.Auth)
}
}

if apiKind != response.TypeMeta.Kind {
t.Errorf("Expected Kind %s, got %s", apiKind, response.TypeMeta.Kind)
}
if apiVersion != response.TypeMeta.APIVersion {
t.Errorf("Expected APIVersion %s, got %s", apiVersion, response.TypeMeta.APIVersion)
}
if expectedCacheKey != response.CacheKeyType {
t.Errorf("Expected %s as cache key (found %s instead)", expectedCacheKey, response.CacheKeyType)
}
for _, auth := range response.Auth {
if expectedUsername != auth.Username {
t.Errorf("Expected username %s not found (username: %s)", expectedUsername, auth.Username)
}
if dummyToken != auth.Password {
t.Errorf("Expected password %s not found (password: %s)", dummyToken, auth.Password)
}
}
}
Loading