From 1b1baacda7cb0a31819d498bd13e6ca03aa597e7 Mon Sep 17 00:00:00 2001 From: michal Date: Mon, 20 Apr 2026 11:21:18 +0200 Subject: [PATCH] Updated validate to fall back to Elastic-Agent-Version header' --- internal/pkg/api/handleCheckin.go | 2 +- internal/pkg/api/handleEnroll.go | 2 +- internal/pkg/api/userAgent.go | 22 ++++++-- internal/pkg/api/userAgent_test.go | 86 ++++++++++++++++++++++++++++-- 4 files changed, 102 insertions(+), 10 deletions(-) diff --git a/internal/pkg/api/handleCheckin.go b/internal/pkg/api/handleCheckin.go index 3e4f530597..e8bed25f1b 100644 --- a/internal/pkg/api/handleCheckin.go +++ b/internal/pkg/api/handleCheckin.go @@ -144,7 +144,7 @@ func (ct *CheckinT) handleCheckin(zlog zerolog.Logger, w http.ResponseWriter, r ctx := zlog.WithContext(r.Context()) r = r.WithContext(ctx) - ver, err := validateUserAgent(r.Context(), zlog, userAgent, ct.verCon) + ver, err := validateUserAgent(r.Context(), zlog, userAgent, r.Header.Get("Elastic-Agent-Version"), ct.verCon) if err != nil { return err } diff --git a/internal/pkg/api/handleEnroll.go b/internal/pkg/api/handleEnroll.go index 4e9ab2e33b..edcbf8f289 100644 --- a/internal/pkg/api/handleEnroll.go +++ b/internal/pkg/api/handleEnroll.go @@ -92,7 +92,7 @@ func (et *EnrollerT) handleEnroll(zlog zerolog.Logger, w http.ResponseWriter, r ctx := zlog.WithContext(r.Context()) r = r.WithContext(ctx) - ver, err := validateUserAgent(r.Context(), zlog, userAgent, et.verCon) + ver, err := validateUserAgent(r.Context(), zlog, userAgent, r.Header.Get("Elastic-Agent-Version"), et.verCon) if err != nil { return err } diff --git a/internal/pkg/api/userAgent.go b/internal/pkg/api/userAgent.go index 992b98c644..fa388ef47e 100644 --- a/internal/pkg/api/userAgent.go +++ b/internal/pkg/api/userAgent.go @@ -47,7 +47,7 @@ func maximizePatch(ver *version.Version) string { // validateUserAgent validates that the User-Agent of the connecting Elastic Agent is valid and that the version is // supported for this Fleet Server. -func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent string, verConst version.Constraints) (string, error) { +func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent string, elasticAgentVersion string, verConst version.Constraints) (string, error) { span, _ := apm.StartSpan(ctx, "userAgent", "validate") defer span.End() zlog = zlog.With().Str("userAgent", userAgent).Logger() @@ -60,6 +60,7 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin } userAgent = strings.ToLower(userAgent) + if !strings.HasPrefix(userAgent, userAgentPrefix) { zlog.Info(). Err(ErrInvalidUserAgent). @@ -77,7 +78,7 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin // Trim leading and traling spaces verStr := strings.TrimSpace(verSep[0]) - ver, err := version.NewVersion(verStr) + ver, err := getVersion(verStr, elasticAgentVersion) if err != nil { zlog.Info(). Err(err). @@ -85,10 +86,11 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin Msg("invalid user agent version string") return "", ErrInvalidUserAgent } + if !verConst.Check(ver) { zlog.Info(). Err(ErrUnsupportedVersion). - Str("verStr", verStr). + Str("verStr", ver.String()). Str("constraints", verConst.String()). Msg("unsuported user agent version") return "", ErrUnsupportedVersion @@ -96,3 +98,17 @@ func validateUserAgent(ctx context.Context, zlog zerolog.Logger, userAgent strin return ver.String(), nil } + +func getVersion(userAgentVersion string, elasticAgentVersion string) (*version.Version, error) { + ver, err := version.NewVersion(userAgentVersion) + if err == nil { + return ver, nil + } + + ver, err = version.NewVersion(elasticAgentVersion) + if err == nil { + return ver, nil + } + + return nil, ErrInvalidUserAgent +} diff --git a/internal/pkg/api/userAgent_test.go b/internal/pkg/api/userAgent_test.go index 9106d08de5..a16f9e19f7 100644 --- a/internal/pkg/api/userAgent_test.go +++ b/internal/pkg/api/userAgent_test.go @@ -13,13 +13,15 @@ import ( "github.com/hashicorp/go-version" "github.com/rs/zerolog" + "github.com/stretchr/testify/require" ) func TestValidateUserAgent(t *testing.T) { tests := []struct { - userAgent string - verCon version.Constraints - err error + userAgent string + verCon version.Constraints + err error + elasticAgentVersion string }{ { userAgent: "", @@ -111,13 +113,87 @@ func TestValidateUserAgent(t *testing.T) { verCon: mustBuildConstraints("8.0.0-beta1"), err: nil, }, + { + userAgent: "Elastic Agent Agentless", + verCon: mustBuildConstraints("8.0.0"), + err: nil, + elasticAgentVersion: "v8.0.0", + }, + { + userAgent: "Elastic Agent Agentless", + verCon: nil, + err: ErrInvalidUserAgent, + }, } for _, tr := range tests { - t.Run(tr.userAgent, func(t *testing.T) { - _, res := validateUserAgent(context.Background(), zerolog.Nop(), tr.userAgent, tr.verCon) + t.Run(tr.userAgent+tr.elasticAgentVersion, func(t *testing.T) { + _, res := validateUserAgent(context.Background(), zerolog.Nop(), tr.userAgent, tr.elasticAgentVersion, tr.verCon) if !errors.Is(tr.err, res) { t.Fatalf("err mismatch: %v != %v", tr.err, res) } }) } } + +func TestGetVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userAgentVersion string + elasticAgentVersion string + want string + wantErr error + }{ + { + name: "parses_user_agent_version_first", + userAgentVersion: "8.0.0", + elasticAgentVersion: "9.0.0", + want: "8.0.0", + }, + { + name: "strips_v_prefix_on_user_agent_version", + userAgentVersion: "v8.1.2", + elasticAgentVersion: "9.0.0", + want: "8.1.2", + }, + { + name: "falls_back_to_elastic_agent_version_header", + userAgentVersion: "agentless", + elasticAgentVersion: "v8.0.0", + want: "8.0.0", + }, + { + name: "empty_user_agent_version_uses_header", + userAgentVersion: "", + elasticAgentVersion: "7.14.0", + want: "7.14.0", + }, + { + name: "invalid_user_agent_version_invalid_header", + userAgentVersion: "not-a-semver", + elasticAgentVersion: "also-invalid", + wantErr: ErrInvalidUserAgent, + }, + { + name: "invalid_user_empty_header", + userAgentVersion: "nope", + elasticAgentVersion: "", + wantErr: ErrInvalidUserAgent, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := getVersion(tt.userAgentVersion, tt.elasticAgentVersion) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + require.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, tt.want, got.String()) + }) + } +}