diff --git a/github/provider.go b/github/provider.go index 2d18019542..2d5d6dc33a 100644 --- a/github/provider.go +++ b/github/provider.go @@ -497,6 +497,19 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc { } } +// ghCLIHostFromAPIHost maps an API hostname to the corresponding +// gh-CLI --hostname value. For example api.github.com -> github.com +// and api..ghe.com -> .ghe.com. +// for unrecognized hostnames, input is returned unmodified. +func ghCLIHostFromAPIHost(host string) string { + if host == DotComAPIHost { + return DotComHost + } else if GHECAPIHostMatch.MatchString(host) { + return strings.TrimPrefix(host, "api.") + } + return host +} + // See https://github.com/integrations/terraform-provider-github/issues/1822 func tokenFromGHCLI(u *url.URL) string { ghCliPath := os.Getenv("GH_PATH") @@ -504,10 +517,7 @@ func tokenFromGHCLI(u *url.URL) string { ghCliPath = "gh" } - host := u.Host - if host == DotComAPIHost { - host = DotComHost - } + host := ghCLIHostFromAPIHost(u.Host) out, err := exec.Command(ghCliPath, "auth", "token", "--hostname", host).Output() if err != nil { diff --git a/github/provider_test.go b/github/provider_test.go index 4f022ce4d4..bb9ab01400 100644 --- a/github/provider_test.go +++ b/github/provider_test.go @@ -259,3 +259,46 @@ data "github_ip_ranges" "test" {} }) }) } + +func Test_ghCLIHostFromAPIHost(t *testing.T) { + testCases := []struct { + name string + host string + expectedHost string + }{ + { + name: "dotcom API host is mapped to dotcom host", + host: "api.github.com", + expectedHost: "github.com", + }, + { + name: "ghec API host has api. prefix stripped", + host: "api.my-enterprise.ghe.com", + expectedHost: "my-enterprise.ghe.com", + }, + { + name: "ghec API host with numbers has api. prefix stripped", + host: "api.customer-123.ghe.com", + expectedHost: "customer-123.ghe.com", + }, + { + name: "ghes host is passed through unchanged", + host: "github.example.com", + expectedHost: "github.example.com", + }, + { + name: "ghes host with port is passed through unchanged", + host: "github.example.com:8443", + expectedHost: "github.example.com:8443", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := ghCLIHostFromAPIHost(tc.host) + if got != tc.expectedHost { + t.Errorf("ghCLIHostFromAPIHost(%q) = %q, want %q", tc.host, got, tc.expectedHost) + } + }) + } +}