Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions src/go/pt-mongodb-summary/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ const (
toolname = "pt-mongodb-summary"

DefaultAuthDB = "admin"
DefaultHost = "mongodb://localhost:27017"
DefaultHost = "localhost"
DefaultPort = "27017"
DefaultLogLevel = "warn"
DefaultRunningOpsInterval = 1000 // milliseconds
DefaultRunningOpsSamples = 5
Expand Down Expand Up @@ -161,6 +162,7 @@ type clusterwideInfo struct {

type cliOptions struct {
Host string
Port string
User string
Password string
AuthDB string
Expand All @@ -170,6 +172,7 @@ type cliOptions struct {
SSLPEMKeyFile string
RunningOpsSamples int
RunningOpsInterval int
URI string
Help bool
Comment thread
svetasmirnova marked this conversation as resolved.
Version bool
NoVersionCheck bool
Expand Down Expand Up @@ -252,13 +255,11 @@ func main() {
log.Errorf("Cannot get hostnames: %s", err)
}

log.Debugf("hostnames: %v", hostnames)

ci := &collectedInfo{}

ci.HostInfo, err = getHostInfo(ctx, client)
if err != nil {
log.Errorf("Cannot get host info for %q: %s", opts.Host, err)
log.Errorf("Cannot get host info for %q: %s", clientOptions.Hosts, err)
os.Exit(cannotGetHostInfo) //nolint:gocritic
}

Expand Down Expand Up @@ -920,7 +921,6 @@ func externalIP() (string, error) {

func parseFlags() (*cliOptions, error) {
opts := &cliOptions{
Host: DefaultHost,
LogLevel: DefaultLogLevel,
RunningOpsSamples: DefaultRunningOpsSamples,
RunningOpsInterval: DefaultRunningOpsInterval, // milliseconds
Expand All @@ -933,6 +933,10 @@ func parseFlags() (*cliOptions, error) {
gop.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit")
gop.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates")

gop.StringVarLong(&opts.URI, "uri", 0, "URI describes the hosts to be used and options. Flags has higher priority.")
Comment thread
BON4 marked this conversation as resolved.
Outdated
gop.StringVarLong(&opts.Host, "host", 0, "Host")
gop.StringVarLong(&opts.Port, "port", 0, "Port")

gop.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication")
gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").
SetOptional()
Expand All @@ -958,7 +962,14 @@ func parseFlags() (*cliOptions, error) {
gop.Parse(os.Args)

if gop.NArgs() > 0 {
opts.Host = gop.Arg(0)
if gop.IsSet("host") || gop.IsSet("port") {
Comment thread
BON4 marked this conversation as resolved.
Outdated
return nil, fmt.Errorf(`parameter host[:port] is not compatible with "--host" and "--port" flags set`)
}
var err error
opts.Host, opts.Port, err = net.SplitHostPort(gop.Arg(0))
if err != nil {
return nil, err
}
gop.Parse(gop.Args())
}

Expand All @@ -973,8 +984,8 @@ func parseFlags() (*cliOptions, error) {
opts.Password = string(pass)
}

if !strings.HasPrefix(opts.Host, "mongodb://") {
opts.Host = "mongodb://" + opts.Host
if gop.IsSet("uri") && (gop.IsSet("host") || gop.IsSet("port")) {
return nil, fmt.Errorf("If a full URI is provided, you cannot also specify --host or --port")
}

if opts.Help {
Expand Down Expand Up @@ -1015,20 +1026,38 @@ func getChunksCount(ctx context.Context, client *mongo.Client) ([]proto.ChunksBy
}

func getClientOptions(opts *cliOptions) (*options.ClientOptions, error) {
clientOptions := options.Client().ApplyURI(opts.Host)
var clientOptions *options.ClientOptions

clientOptions.ServerSelectionTimeout = &defaultConnectionTimeout
clientOptions.Direct = &directConnection
credential := options.Credential{}
if opts.User != "" {
credential.Username = opts.User
clientOptions.SetAuth(credential)
if opts.URI != "" {
clientOptions = options.Client().ApplyURI(opts.URI)
} else {
host := opts.Host
if host == "" {
host = DefaultHost
}
port := opts.Port
if port == "" {
port = DefaultPort
}

clientOptions = options.Client().ApplyURI("mongodb://" + net.JoinHostPort(host, port))
}

auth := clientOptions.Auth
if auth == nil {
auth = &options.Credential{}
}

if opts.User != "" {
auth.Username = opts.User
}
if opts.Password != "" {
credential.Password = opts.Password
credential.PasswordSet = true
clientOptions.SetAuth(credential)
auth.Password = opts.Password
auth.PasswordSet = true
}

if auth.Username != "" {
clientOptions.SetAuth(*auth)
}

if opts.SSLPEMKeyFile != "" || opts.SSLCAFile != "" {
Expand All @@ -1040,7 +1069,15 @@ func getClientOptions(opts *cliOptions) (*options.ClientOptions, error) {
clientOptions.TLSConfig = tlsConfig
}

return clientOptions, nil
// Defaults
if clientOptions.ServerSelectionTimeout == nil {
clientOptions.ServerSelectionTimeout = &defaultConnectionTimeout
}
if clientOptions.Direct == nil {
clientOptions.Direct = &directConnection
}

return clientOptions, clientOptions.Validate()
}

func getTLSConfig(sslPEMKeyFile, sslCAFile string) (*tls.Config, error) {
Expand Down
184 changes: 163 additions & 21 deletions src/go/pt-mongodb-summary/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
package main

import (
"testing"
"time"

"context"
"os"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/mongo/options"

"github.com/pborman/getopt"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -110,15 +114,43 @@ func TestClusterWideInfo(t *testing.T) {
}
}

func TestParseArgs(t *testing.T) {
func TestParseFlags(t *testing.T) {
tests := []struct {
args []string
want *cliOptions
name string
args []string
want *cliOptions
wantErr bool
}{
{
args: []string{toolname}, // arg[0] is the command itself
name: "Default values",
args: []string{toolname},
want: &cliOptions{
Host: "",
LogLevel: DefaultLogLevel,
AuthDB: DefaultAuthDB,
RunningOpsSamples: DefaultRunningOpsSamples,
RunningOpsInterval: DefaultRunningOpsInterval,
OutputFormat: "text",
},
},
{
name: "URI only",
args: []string{toolname, "--uri", "mongodb://test:27017"},
want: &cliOptions{
URI: "mongodb://test:27017",
LogLevel: DefaultLogLevel,
AuthDB: DefaultAuthDB,
RunningOpsSamples: DefaultRunningOpsSamples,
RunningOpsInterval: DefaultRunningOpsInterval,
OutputFormat: "text",
},
},
{
name: "Legacy positional host:port",
args: []string{toolname, "test.example.com:27019"},
want: &cliOptions{
Host: DefaultHost,
Host: "test.example.com",
Port: "27019",
LogLevel: DefaultLogLevel,
AuthDB: DefaultAuthDB,
RunningOpsSamples: DefaultRunningOpsSamples,
Expand All @@ -127,27 +159,137 @@ func TestParseArgs(t *testing.T) {
},
},
{
args: []string{toolname, "zapp.brannigan.net:27018/samples", "--help"},
name: "Error: URI and Host together",
args: []string{toolname, "--uri", "mongodb://test", "--host", "localhost"},
wantErr: true,
},
{
name: "Error: Positional arg and Host flag together",
args: []string{toolname, "--host", "newhost", "legacy:27017"},
wantErr: true,
},
{
name: "Help flag returns nil options",
args: []string{toolname, "--help"},
want: nil,
},
}

// Capture stdout to not to show help
old := os.Stdout // keep backup of the real stdout
// Backup and silence stdout
oldStdout := os.Stdout
_, w, _ := os.Pipe()
os.Stdout = w
defer func() { os.Stdout = oldStdout }()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
getopt.Reset()
os.Args = tt.args

got, err := parseFlags()

if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
}
return
}

if err != nil {
t.Errorf("unexpected error: %v", err)
return
}

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("mismatch:\ngot: %+v\nwant: %+v", got, tt.want)
}
})
}
}

for i, test := range tests {
getopt.Reset()
os.Args = test.args
got, err := parseFlags()
if err != nil {
t.Errorf("error parsing command line arguments: %s", err.Error())
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("invalid command line options test %d\ngot %+v\nwant %+v\n", i, got, test.want)
}
func TestGetClientOptions(t *testing.T) {
tests := []struct {
name string
opts *cliOptions
wantErr bool
validate func(*testing.T, *options.ClientOptions)
}{
{
name: "Default values when everything is empty",
opts: &cliOptions{},
validate: func(t *testing.T, co *options.ClientOptions) {
assert.Equal(t, []string{"localhost:27017"}, co.Hosts)
assert.Nil(t, co.Auth)
},
},
{
name: "Priority to URI",
opts: &cliOptions{
URI: "mongodb://remote-host:28000",
},
validate: func(t *testing.T, co *options.ClientOptions) {
assert.Equal(t, []string{"remote-host:28000"}, co.Hosts)
},
},
{
name: "Flags override Auth in URI",
opts: &cliOptions{
URI: "mongodb://old-user:old-pass@localhost:27017",
User: "new-user",
Password: "new-password",
},
validate: func(t *testing.T, co *options.ClientOptions) {
assert.Equal(t, "new-user", co.Auth.Username)
assert.Equal(t, "new-password", co.Auth.Password)
},
},
{
name: "Only host and port flags",
opts: &cliOptions{
Host: "127.0.0.1",
Port: "27019",
},
validate: func(t *testing.T, co *options.ClientOptions) {
assert.Equal(t, []string{"127.0.0.1:27019"}, co.Hosts)
},
},
{
name: "Invalid URI should return error",
opts: &cliOptions{
URI: "not-a-valid-uri",
},
wantErr: true,
},
{
name: "AuthDB via URI (check if preserved)",
opts: &cliOptions{
URI: "mongodb://user@localhost:27017/admin?authSource=custom_db",
},
validate: func(t *testing.T, co *options.ClientOptions) {
assert.Equal(t, "user", co.Auth.Username)
assert.Equal(t, "custom_db", co.Auth.AuthSource)
},
},
}

os.Stdout = old
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getClientOptions(tt.opts)

if tt.wantErr {
assert.Error(t, err)
return
}

assert.NoError(t, err)
assert.NotNil(t, got)

if tt.validate != nil {
tt.validate(t, got)
}

assert.NotNil(t, got.ServerSelectionTimeout)
assert.True(t, *got.Direct)
})
}
}
Loading