diff --git a/src/go/pt-mongodb-summary/README.rst b/src/go/pt-mongodb-summary/README.rst index 1bfc4041f..068e096a6 100644 --- a/src/go/pt-mongodb-summary/README.rst +++ b/src/go/pt-mongodb-summary/README.rst @@ -32,15 +32,33 @@ For better results, host must be a **mongos** server. Options ------- -``-a``, ``--auth-db`` +``-a``, ``--authenticationDatabase`` Specifies the database used to establish credentials and privileges with a MongoDB server. By default, the ``admin`` database is used. +``-c``, ``--no-version-check`` + Disables checking the version of MongoDB before running the report. + ``-f``, ``--output-format`` Specifies the report output format. Valid options are: ``text``, ``json``. The default value is ``text``. +``-h``, ``--help`` + Show help message and exit. + +``--host`` + Specifies the hostname or IP address of the MongoDB server to connect to. + +``-i``, ``--running-ops-interval`` + Interval in milliseconds to wait between samples of running operations. + Default: 1000 milliseconds. + +``-l``, ``--log-level`` + Specifies the logging level. Valid options: ``panic``, ``fatal``, ``error``, + ``warn``, ``info``, ``debug``. + Default: ``error``. + ``-p``, ``--password`` Specifies the password to use when connecting to a server with authentication enabled. @@ -48,12 +66,34 @@ Options Do not add a space between the option and its value: ``-p``. If you specify the option without any value, - ``pt-mongodb-summary`` will ask for password interactively. + ``pt-mongodb-summary`` will ask for the password interactively. + +``--port`` + Specifies the port of the MongoDB server to connect to. -``-u``, ``--user`` - Specifies the user name for connecting to a server +``-s``, ``--running-ops-samples`` + Number of samples to collect for running operations. + Default: 5. + +``--sslCAFile`` + Path to the SSL CA certificate file used for authentication. + +``--sslPEMKeyFile`` + Path to the SSL client PEM file used for authentication. + +``--uri`` + Full MongoDB URI describing hosts and options. + Command-line flags have higher priority than URI settings. + If a full URI is provided, you cannot also specify ``--host`` or ``--port``. + Example: ``mongodb://admin:secret@localhost:27017`` + +``-u``, ``--username`` + Specifies the username to use when connecting to a server with authentication enabled. +``-v``, ``--version`` + Show version information and exit. + Output example ============== diff --git a/src/go/pt-mongodb-summary/main.go b/src/go/pt-mongodb-summary/main.go index 66bff3343..648931819 100644 --- a/src/go/pt-mongodb-summary/main.go +++ b/src/go/pt-mongodb-summary/main.go @@ -51,7 +51,8 @@ const ( toolname = "pt-mongodb-summary" DefaultAuthDB = "admin" - DefaultHost = "mongodb://localhost:27017" + DefaultHost = "localhost" + DefaultPort = "27017" DefaultLogLevel = "warn" DefaultRunningOpsInterval = 1000 // milliseconds DefaultRunningOpsSamples = 5 @@ -161,6 +162,7 @@ type clusterwideInfo struct { type cliOptions struct { Host string + Port string User string Password string AuthDB string @@ -170,6 +172,7 @@ type cliOptions struct { SSLPEMKeyFile string RunningOpsSamples int RunningOpsInterval int + URI string Help bool Version bool NoVersionCheck bool @@ -252,13 +255,11 @@ func main() { log.Errorf("Cannot get hostnames: %s", err) } - log.Debugf("hostnames: %v", hostnames) - ci := &collectedInfo{} ci.HostInfo, err = getHostInfo(ctx, client) if err != nil { - log.Errorf("Cannot get host info for %q: %s", opts.Host, err) + log.Errorf("Cannot get host info for %q: %s", clientOptions.Hosts, err) os.Exit(cannotGetHostInfo) //nolint:gocritic } @@ -920,7 +921,6 @@ func externalIP() (string, error) { func parseFlags() (*cliOptions, error) { opts := &cliOptions{ - Host: DefaultHost, LogLevel: DefaultLogLevel, RunningOpsSamples: DefaultRunningOpsSamples, RunningOpsInterval: DefaultRunningOpsInterval, // milliseconds @@ -933,6 +933,10 @@ func parseFlags() (*cliOptions, error) { gop.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit") gop.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates") + gop.StringVarLong(&opts.URI, "uri", 0, `URI describes the hosts to be used and options. Flags has higher priority. If a full URI is provided, you cannot also specify "--host" or "--port".`) + gop.StringVarLong(&opts.Host, "host", 0, "Host") + gop.StringVarLong(&opts.Port, "port", 0, "Port") + gop.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication") gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication"). SetOptional() @@ -958,7 +962,14 @@ func parseFlags() (*cliOptions, error) { gop.Parse(os.Args) if gop.NArgs() > 0 { - opts.Host = gop.Arg(0) + if gop.IsSet("host") || gop.IsSet("port") || gop.IsSet("uri") { + return nil, fmt.Errorf(`parameter host[:port] is not compatible with "--uri", "--host" and "--port" flags set`) + } + var err error + opts.Host, opts.Port, err = net.SplitHostPort(gop.Arg(0)) + if err != nil { + return nil, err + } gop.Parse(gop.Args()) } @@ -973,8 +984,8 @@ func parseFlags() (*cliOptions, error) { opts.Password = string(pass) } - if !strings.HasPrefix(opts.Host, "mongodb://") { - opts.Host = "mongodb://" + opts.Host + if gop.IsSet("uri") && (gop.IsSet("host") || gop.IsSet("port")) { + return nil, fmt.Errorf("If a full URI is provided, you cannot also specify --host or --port") } if opts.Help { @@ -1015,20 +1026,38 @@ func getChunksCount(ctx context.Context, client *mongo.Client) ([]proto.ChunksBy } func getClientOptions(opts *cliOptions) (*options.ClientOptions, error) { - clientOptions := options.Client().ApplyURI(opts.Host) + var clientOptions *options.ClientOptions - clientOptions.ServerSelectionTimeout = &defaultConnectionTimeout - clientOptions.Direct = &directConnection - credential := options.Credential{} - if opts.User != "" { - credential.Username = opts.User - clientOptions.SetAuth(credential) + if opts.URI != "" { + clientOptions = options.Client().ApplyURI(opts.URI) + } else { + host := opts.Host + if host == "" { + host = DefaultHost + } + port := opts.Port + if port == "" { + port = DefaultPort + } + + clientOptions = options.Client().ApplyURI("mongodb://" + net.JoinHostPort(host, port)) + } + + auth := clientOptions.Auth + if auth == nil { + auth = &options.Credential{} } + if opts.User != "" { + auth.Username = opts.User + } if opts.Password != "" { - credential.Password = opts.Password - credential.PasswordSet = true - clientOptions.SetAuth(credential) + auth.Password = opts.Password + auth.PasswordSet = true + } + + if auth.Username != "" { + clientOptions.SetAuth(*auth) } if opts.SSLPEMKeyFile != "" || opts.SSLCAFile != "" { @@ -1040,7 +1069,15 @@ func getClientOptions(opts *cliOptions) (*options.ClientOptions, error) { clientOptions.TLSConfig = tlsConfig } - return clientOptions, nil + // Defaults + if clientOptions.ServerSelectionTimeout == nil { + clientOptions.ServerSelectionTimeout = &defaultConnectionTimeout + } + if clientOptions.Direct == nil { + clientOptions.Direct = &directConnection + } + + return clientOptions, clientOptions.Validate() } func getTLSConfig(sslPEMKeyFile, sslCAFile string) (*tls.Config, error) { diff --git a/src/go/pt-mongodb-summary/main_test.go b/src/go/pt-mongodb-summary/main_test.go index 0f1a9e69e..296192087 100644 --- a/src/go/pt-mongodb-summary/main_test.go +++ b/src/go/pt-mongodb-summary/main_test.go @@ -14,11 +14,15 @@ package main import ( + "testing" + "time" + "context" "os" "reflect" - "testing" - "time" + + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/mongo/options" "github.com/pborman/getopt" "github.com/stretchr/testify/require" @@ -110,15 +114,43 @@ func TestClusterWideInfo(t *testing.T) { } } -func TestParseArgs(t *testing.T) { +func TestParseFlags(t *testing.T) { tests := []struct { - args []string - want *cliOptions + name string + args []string + want *cliOptions + wantErr bool }{ { - args: []string{toolname}, // arg[0] is the command itself + name: "Default values", + args: []string{toolname}, + want: &cliOptions{ + Host: "", + LogLevel: DefaultLogLevel, + AuthDB: DefaultAuthDB, + RunningOpsSamples: DefaultRunningOpsSamples, + RunningOpsInterval: DefaultRunningOpsInterval, + OutputFormat: "text", + }, + }, + { + name: "URI only", + args: []string{toolname, "--uri", "mongodb://test:27017"}, + want: &cliOptions{ + URI: "mongodb://test:27017", + LogLevel: DefaultLogLevel, + AuthDB: DefaultAuthDB, + RunningOpsSamples: DefaultRunningOpsSamples, + RunningOpsInterval: DefaultRunningOpsInterval, + OutputFormat: "text", + }, + }, + { + name: "Legacy positional host:port", + args: []string{toolname, "test.example.com:27019"}, want: &cliOptions{ - Host: DefaultHost, + Host: "test.example.com", + Port: "27019", LogLevel: DefaultLogLevel, AuthDB: DefaultAuthDB, RunningOpsSamples: DefaultRunningOpsSamples, @@ -127,27 +159,137 @@ func TestParseArgs(t *testing.T) { }, }, { - args: []string{toolname, "zapp.brannigan.net:27018/samples", "--help"}, + name: "Error: URI and Host together", + args: []string{toolname, "--uri", "mongodb://test", "--host", "localhost"}, + wantErr: true, + }, + { + name: "Error: Positional arg and Host flag together", + args: []string{toolname, "--host", "newhost", "legacy:27017"}, + wantErr: true, + }, + { + name: "Help flag returns nil options", + args: []string{toolname, "--help"}, want: nil, }, } - // Capture stdout to not to show help - old := os.Stdout // keep backup of the real stdout + // Backup and silence stdout + oldStdout := os.Stdout _, w, _ := os.Pipe() os.Stdout = w + defer func() { os.Stdout = oldStdout }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + getopt.Reset() + os.Args = tt.args + + got, err := parseFlags() + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("mismatch:\ngot: %+v\nwant: %+v", got, tt.want) + } + }) + } +} - for i, test := range tests { - getopt.Reset() - os.Args = test.args - got, err := parseFlags() - if err != nil { - t.Errorf("error parsing command line arguments: %s", err.Error()) - } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("invalid command line options test %d\ngot %+v\nwant %+v\n", i, got, test.want) - } +func TestGetClientOptions(t *testing.T) { + tests := []struct { + name string + opts *cliOptions + wantErr bool + validate func(*testing.T, *options.ClientOptions) + }{ + { + name: "Default values when everything is empty", + opts: &cliOptions{}, + validate: func(t *testing.T, co *options.ClientOptions) { + assert.Equal(t, []string{"localhost:27017"}, co.Hosts) + assert.Nil(t, co.Auth) + }, + }, + { + name: "Priority to URI", + opts: &cliOptions{ + URI: "mongodb://remote-host:28000", + }, + validate: func(t *testing.T, co *options.ClientOptions) { + assert.Equal(t, []string{"remote-host:28000"}, co.Hosts) + }, + }, + { + name: "Flags override Auth in URI", + opts: &cliOptions{ + URI: "mongodb://old-user:old-pass@localhost:27017", + User: "new-user", + Password: "new-password", + }, + validate: func(t *testing.T, co *options.ClientOptions) { + assert.Equal(t, "new-user", co.Auth.Username) + assert.Equal(t, "new-password", co.Auth.Password) + }, + }, + { + name: "Only host and port flags", + opts: &cliOptions{ + Host: "127.0.0.1", + Port: "27019", + }, + validate: func(t *testing.T, co *options.ClientOptions) { + assert.Equal(t, []string{"127.0.0.1:27019"}, co.Hosts) + }, + }, + { + name: "Invalid URI should return error", + opts: &cliOptions{ + URI: "not-a-valid-uri", + }, + wantErr: true, + }, + { + name: "AuthDB via URI (check if preserved)", + opts: &cliOptions{ + URI: "mongodb://user@localhost:27017/admin?authSource=custom_db", + }, + validate: func(t *testing.T, co *options.ClientOptions) { + assert.Equal(t, "user", co.Auth.Username) + assert.Equal(t, "custom_db", co.Auth.AuthSource) + }, + }, } - os.Stdout = old + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getClientOptions(tt.opts) + + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.NotNil(t, got) + + if tt.validate != nil { + tt.validate(t, got) + } + + assert.NotNil(t, got.ServerSelectionTimeout) + assert.True(t, *got.Direct) + }) + } }