Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/pkg/api/handleCheckin.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (ct *CheckinT) handleCheckin(zlog zerolog.Logger, w http.ResponseWriter, r
ctx := zlog.WithContext(r.Context())
r = r.WithContext(ctx)

ver, err := validateUserAgent(r.Context(), zlog, userAgent, ct.verCon)
ver, err := validateUserAgent(r.Context(), zlog, userAgent, r.Header.Get("Elastic-Agent-Version"), ct.verCon)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this into the openapi spec?

if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/api/handleEnroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (et *EnrollerT) handleEnroll(zlog zerolog.Logger, w http.ResponseWriter, r
ctx := zlog.WithContext(r.Context())
r = r.WithContext(ctx)

ver, err := validateUserAgent(r.Context(), zlog, userAgent, et.verCon)
ver, err := validateUserAgent(r.Context(), zlog, userAgent, r.Header.Get("Elastic-Agent-Version"), et.verCon)
if err != nil {
return err
}
Expand Down
22 changes: 19 additions & 3 deletions internal/pkg/api/userAgent.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func maximizePatch(ver *version.Version) string {

// validateUserAgent validates that the User-Agent of the connecting Elastic Agent is valid and that the version is
// supported for this Fleet Server.
func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent string, verConst version.Constraints) (string, error) {
func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent string, elasticAgentVersion string, verConst version.Constraints) (string, error) {
span, _ := apm.StartSpan(ctx, "userAgent", "validate")
defer span.End()
zlog = zlog.With().Str("userAgent", userAgent).Logger()
Expand All @@ -60,6 +60,7 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin
}

userAgent = strings.ToLower(userAgent)

if !strings.HasPrefix(userAgent, userAgentPrefix) {
zlog.Info().
Err(ErrInvalidUserAgent).
Expand All @@ -77,22 +78,37 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin
// Trim leading and traling spaces
verStr := strings.TrimSpace(verSep[0])

ver, err := version.NewVersion(verStr)
ver, err := getVersion(verStr, elasticAgentVersion)
if err != nil {
zlog.Info().
Err(err).
Str("verStr", verStr).
Msg("invalid user agent version string")
return "", ErrInvalidUserAgent
}

if !verConst.Check(ver) {
zlog.Info().
Err(ErrUnsupportedVersion).
Str("verStr", verStr).
Str("verStr", ver.String()).
Str("constraints", verConst.String()).
Msg("unsuported user agent version")
return "", ErrUnsupportedVersion
}

return ver.String(), nil
}

func getVersion(userAgentVersion string, elasticAgentVersion string) (*version.Version, error) {
ver, err := version.NewVersion(userAgentVersion)
if err == nil {
return ver, nil
}

ver, err = version.NewVersion(elasticAgentVersion)
if err == nil {
return ver, nil
}

return nil, ErrInvalidUserAgent
Comment on lines +108 to +113
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ErrInvalidUserAgent seems overly broad for what this function is checking. Also, that error is being returned from the error checking being done at this function's call site. So maybe this function should just return the raw error at the end?

Suggested change
ver, err = version.NewVersion(elasticAgentVersion)
if err == nil {
return ver, nil
}
return nil, ErrInvalidUserAgent
// Fallback
return version.NewVersion(elasticAgentVersion)

}
86 changes: 81 additions & 5 deletions internal/pkg/api/userAgent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ import (

"github.com/hashicorp/go-version"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)

func TestValidateUserAgent(t *testing.T) {
tests := []struct {
userAgent string
verCon version.Constraints
err error
userAgent string
verCon version.Constraints
err error
elasticAgentVersion string
}{
{
userAgent: "",
Expand Down Expand Up @@ -111,13 +113,87 @@ func TestValidateUserAgent(t *testing.T) {
verCon: mustBuildConstraints("8.0.0-beta1"),
err: nil,
},
{
userAgent: "Elastic Agent Agentless",
verCon: mustBuildConstraints("8.0.0"),
err: nil,
elasticAgentVersion: "v8.0.0",
},
{
userAgent: "Elastic Agent Agentless",
verCon: nil,
err: ErrInvalidUserAgent,
},
}
for _, tr := range tests {
t.Run(tr.userAgent, func(t *testing.T) {
_, res := validateUserAgent(context.Background(), zerolog.Nop(), tr.userAgent, tr.verCon)
t.Run(tr.userAgent+tr.elasticAgentVersion, func(t *testing.T) {
_, res := validateUserAgent(context.Background(), zerolog.Nop(), tr.userAgent, tr.elasticAgentVersion, tr.verCon)
if !errors.Is(tr.err, res) {
t.Fatalf("err mismatch: %v != %v", tr.err, res)
}
})
}
}

func TestGetVersion(t *testing.T) {
t.Parallel()

tests := []struct {
name string
userAgentVersion string
elasticAgentVersion string
want string
wantErr error
}{
{
name: "parses_user_agent_version_first",
userAgentVersion: "8.0.0",
elasticAgentVersion: "9.0.0",
want: "8.0.0",
},
{
name: "strips_v_prefix_on_user_agent_version",
userAgentVersion: "v8.1.2",
elasticAgentVersion: "9.0.0",
want: "8.1.2",
},
{
name: "falls_back_to_elastic_agent_version_header",
userAgentVersion: "agentless",
elasticAgentVersion: "v8.0.0",
want: "8.0.0",
},
{
name: "empty_user_agent_version_uses_header",
userAgentVersion: "",
elasticAgentVersion: "7.14.0",
want: "7.14.0",
},
{
name: "invalid_user_agent_version_invalid_header",
userAgentVersion: "not-a-semver",
elasticAgentVersion: "also-invalid",
wantErr: ErrInvalidUserAgent,
},
{
name: "invalid_user_empty_header",
userAgentVersion: "nope",
elasticAgentVersion: "",
wantErr: ErrInvalidUserAgent,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := getVersion(tt.userAgentVersion, tt.elasticAgentVersion)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
require.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, tt.want, got.String())
})
}
}
Loading