Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/stackit_auth_login.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ stackit auth login [flags]
### Options

```
-h, --help Help for "stackit auth login"
-h, --help Help for "stackit auth login"
--port int The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020.
When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020.
```

### Options inherited from parent commands
Expand Down
48 changes: 44 additions & 4 deletions internal/cmd/auth/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ package login
import (
"fmt"

"github.com/stackitcloud/stackit-cli/internal/pkg/types"

"github.com/stackitcloud/stackit-cli/internal/pkg/args"
"github.com/stackitcloud/stackit-cli/internal/pkg/auth"
"github.com/stackitcloud/stackit-cli/internal/pkg/examples"
"github.com/stackitcloud/stackit-cli/internal/pkg/flags"
"github.com/stackitcloud/stackit-cli/internal/pkg/print"
"github.com/stackitcloud/stackit-cli/internal/pkg/types"

"github.com/spf13/cobra"
)

const (
portFlag = "port"
)

type inputModel struct {
Port *int
}

func NewCmd(params *types.CmdParams) *cobra.Command {
cmd := &cobra.Command{
Use: "login",
Expand All @@ -25,8 +34,16 @@ func NewCmd(params *types.CmdParams) *cobra.Command {
`Login to the STACKIT CLI. This command will open a browser window where you can login to your STACKIT account`,
"$ stackit auth login"),
),
RunE: func(_ *cobra.Command, _ []string) error {
err := auth.AuthorizeUser(params.Printer, false)
RunE: func(cmd *cobra.Command, args []string) error {
model, err := parseInput(params.Printer, cmd, args)
if err != nil {
return err
}

err = auth.AuthorizeUser(params.Printer, auth.UserAuthConfig{
IsReauthentication: false,
Port: model.Port,
})
if err != nil {
return fmt.Errorf("authorization failed: %w", err)
}
Expand All @@ -36,5 +53,28 @@ func NewCmd(params *types.CmdParams) *cobra.Command {
return nil
},
}
configureFlags(cmd)
return cmd
}

func configureFlags(cmd *cobra.Command) {
cmd.Flags().Int(portFlag, 0,
"The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020.\n"+
"When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020.",
)
}

func parseInput(p *print.Printer, cmd *cobra.Command, _ []string) (*inputModel, error) {
port := flags.FlagToIntPointer(p, cmd, portFlag)
// For the CLI client only callback URLs with localhost:[8000-8020] are valid. Additional callbacks must be enabled in the backend.
if port != nil && (*port < 8000 || *port > 8020) {
Comment thread
marceljk marked this conversation as resolved.
Outdated
return nil, fmt.Errorf("port must be between 8000 and 8020")
}

model := inputModel{
Port: port,
}

p.DebugInputModel(model)
return &model, nil
}
93 changes: 93 additions & 0 deletions internal/cmd/auth/login/login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package login

import (
"testing"

"github.com/stackitcloud/stackit-cli/internal/pkg/testutils"
"github.com/stackitcloud/stackit-cli/internal/pkg/utils"
)

func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]string {
flagValues := map[string]string{
portFlag: "8010",
}
for _, mod := range mods {
mod(flagValues)
}
return flagValues
}

func fixtureInputModel(mods ...func(model *inputModel)) *inputModel {
model := &inputModel{
Port: utils.Ptr(8010),
}
for _, mod := range mods {
mod(model)
}
return model
}

func TestParseInput(t *testing.T) {
tests := []struct {
description string
flagValues map[string]string
argValues []string
isValid bool
expectedModel *inputModel
}{
{
description: "base",
flagValues: fixtureFlagValues(),
isValid: true,
expectedModel: fixtureInputModel(),
},
{
description: "no values",
flagValues: map[string]string{},
isValid: true,
expectedModel: &inputModel{
Port: nil,
},
},
{
description: "lower limit",
flagValues: map[string]string{
portFlag: "8000",
},
isValid: true,
expectedModel: &inputModel{
Port: utils.Ptr(8000),
},
},
{
description: "below lower limit is not valid ",
flagValues: map[string]string{
portFlag: "7999",
},
isValid: false,
},
{
description: "upper limit",
flagValues: map[string]string{
portFlag: "8020",
},
isValid: true,
expectedModel: &inputModel{
Port: utils.Ptr(8020),
},
},
{
description: "above upper limit is not valid ",
flagValues: map[string]string{
portFlag: "8021",
},
isValid: false,
},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
testutils.TestParseInput(t, NewCmd, parseInput, tt.expectedModel, tt.argValues, tt.flagValues, tt.isValid)
})
}
}
7 changes: 5 additions & 2 deletions internal/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type tokenClaims struct {
//
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
// If the environment variable STACKIT_ACCESS_TOKEN is set this token is used instead.
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, _ UserAuthConfig) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
// Get access token from env and use this if present
accessToken := os.Getenv(envAccessTokenName)
if accessToken != "" {
Expand Down Expand Up @@ -70,7 +70,10 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print
case AUTH_FLOW_USER_TOKEN:
p.Debug(print.DebugLevel, "authenticating using user token")
if userSessionExpired {
err = reauthorizeUserRoutine(p, true)
err = reauthorizeUserRoutine(p, UserAuthConfig{
IsReauthentication: true,
Port: nil,
})
if err != nil {
return nil, fmt.Errorf("user login: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestAuthenticationConfig(t *testing.T) {
}

reauthorizeUserCalled := false
reauthenticateUser := func(_ *print.Printer, _ bool) error {
reauthenticateUser := func(_ *print.Printer, _ UserAuthConfig) error {
if reauthorizeUserCalled {
t.Errorf("user reauthorized more than once")
}
Expand Down
23 changes: 18 additions & 5 deletions internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,19 @@ type InputValues struct {
Logo string
}

type UserAuthConfig struct {
// IsReauthentication defines if an expired user session should be renewed
IsReauthentication bool
// Port defines which port should be used for the UserAuthFlow callback
Port *int
}

type apiClient interface {
Do(req *http.Request) (*http.Response, error)
}

// AuthorizeUser implements the PKCE OAuth2 flow.
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error {
idpWellKnownConfig, err := retrieveIDPWellKnownConfig(p)
if err != nil {
return err
Expand All @@ -68,7 +75,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
}
}

if isReauthentication {
if authConfig.IsReauthentication {
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
if err != nil {
return err
Expand All @@ -79,8 +86,14 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
var listener net.Listener
var listenerErr error
var port int
for i := range configuredPortRange {
port = defaultPort + i
startingPort := defaultPort
portRange := configuredPortRange
if authConfig.Port != nil {
startingPort = *authConfig.Port
portRange = 1
}
for i := range portRange {
port = startingPort + i
portString := fmt.Sprintf(":%s", strconv.Itoa(port))
p.Debug(print.DebugLevel, "trying to bind port %d for login redirect", port)
listener, listenerErr = net.Listen("tcp", portString)
Expand All @@ -92,7 +105,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
p.Debug(print.DebugLevel, "unable to bind port %d for login redirect: %s", port, listenerErr)
}
if listenerErr != nil {
return fmt.Errorf("unable to bind port for login redirect, tried from port %d to %d: %w", defaultPort, port, err)
return fmt.Errorf("unable to bind port for login redirect, tried from port %d to %d: %w", defaultPort, port, listenerErr)
}

conf := &oauth2.Config{
Expand Down
9 changes: 7 additions & 2 deletions internal/pkg/auth/user_token_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type userTokenFlow struct {
printer *print.Printer
reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error // Called if the user needs to login again
reauthorizeUserRoutine func(p *print.Printer, isReauthentication UserAuthConfig) error // Called if the user needs to login again
client *http.Client
authFlow AuthFlow
accessToken string
Expand Down Expand Up @@ -95,7 +95,12 @@ func loadVarsFromStorage(utf *userTokenFlow) error {
}

func reauthenticateUser(utf *userTokenFlow) error {
err := utf.reauthorizeUserRoutine(utf.printer, true)
err := utf.reauthorizeUserRoutine(
utf.printer,
UserAuthConfig{
IsReauthentication: true,
},
)
if err != nil {
return fmt.Errorf("authenticate user: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/user_token_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func TestRoundTrip(t *testing.T) {
authorizeUserCalled: &authorizeUserCalled,
tokensRefreshed: &tokensRefreshed,
}
authorizeUserRoutine := func(_ *print.Printer, _ bool) error {
authorizeUserRoutine := func(_ *print.Printer, _ UserAuthConfig) error {
return reauthorizeUser(authorizeUserContext)
}

Expand Down
14 changes: 14 additions & 0 deletions internal/pkg/flags/flag_to_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ func FlagToStringToStringPointer(p *print.Printer, cmd *cobra.Command, flag stri
return nil
}

// Returns a pointer to the flag's value.
// Returns nil if the flag is not set, if its value can not be converted to int, or if the flag does not exist.
func FlagToIntPointer(p *print.Printer, cmd *cobra.Command, flag string) *int {
value, err := cmd.Flags().GetInt(flag)
if err != nil {
p.Debug(print.ErrorLevel, "convert flag to Uint64 pointer: %v", err)
return nil
}
if cmd.Flag(flag).Changed {
return &value
}
return nil
}

// Returns a pointer to the flag's value.
// Returns nil if the flag is not set, if its value can not be converted to int64, or if the flag does not exist.
func FlagToInt64Pointer(p *print.Printer, cmd *cobra.Command, flag string) *int64 {
Expand Down
Loading