diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eea9863c..0991f38e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: steps: - uses: actions/setup-go@v5 with: - go-version: '1.21.13' + go-version: '1.24.11' - uses: actions/checkout@v4.1.3 #- name: download libraries # run: go mod download @@ -32,7 +32,7 @@ jobs: # Caching conflicts happen in GHA, so just disable for now skip-cache: true # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. - version: v1.59.1 + version: v1.64.2 unit-tests: name: Unit tests runs-on: ubuntu-latest @@ -77,7 +77,7 @@ jobs: - uses: actions/checkout@v4.1.3 - uses: actions/setup-go@v5 with: - go-version: '1.21.13' + go-version: '1.24.11' - name: Verify image builds run: | docker build --tag infrawatch/sg-core:latest --file build/Dockerfile . diff --git a/.github/workflows/updates.yml b/.github/workflows/updates.yml index 5aecab1b..96e5970d 100644 --- a/.github/workflows/updates.yml +++ b/.github/workflows/updates.yml @@ -17,18 +17,31 @@ jobs: # (github.event.issue.author_association == 'MEMBER') # ) runs-on: ubuntu-latest + permissions: + pull-requests: write + issues: write steps: - name: update PR with coveralls badge uses: actions/github-script@v7.0.1 + continue-on-error: true with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | - var BRANCH_NAME = process.env.BRANCH_NAME; - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: `[![Coverage Status](https://coveralls.io/repos/github/${context.repo.owner}/${context.repo.repo}/badge.svg?branch=${BRANCH_NAME})](https://coveralls.io/github/${context.repo.owner}/${context.repo.repo}?branch=${BRANCH_NAME})` - }) + const branchName = process.env.BRANCH_NAME; + if (!branchName) { + console.log("No branch name found, skipping badge update."); + return; + } + + try { + await github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `[![Coverage Status](https://coveralls.io/repos/github/${context.repo.owner}/${context.repo.repo}/badge.svg?branch=${branchName})](https://coveralls.io/github/${context.repo.owner}/${context.repo.repo}?branch=${branchName})` + }); + } catch (error) { + console.log("Could not post comment: ", error.message); + } env: BRANCH_NAME: ${{ github.event.pull_request.head.ref }} diff --git a/.golangci.yaml b/.golangci.yaml index d4533893..d9a1e172 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -37,7 +37,7 @@ linters: - dupl - errcheck # - exhaustive - - exportloopref + - copyloopvar # - gochecknoinits - goconst - gocritic @@ -61,5 +61,5 @@ linters: - unconvert # NOTE: not all application plugins use ability to emit internal events through # passed bus function in it's constructor. - #- unparam + # - unparam # - whitespace diff --git a/build/Dockerfile b/build/Dockerfile index 4bd8becf..c40f26c4 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -10,7 +10,7 @@ COPY . $D/ COPY build/repos/opstools.repo /etc/yum.repos.d/CentOS-OpsTools.repo RUN dnf install golang git qpid-proton-c-devel -y --setopt=tsflags=nodocs -RUN go install golang.org/dl/go1.21.13@latest && /go/bin/go1.21.13 download && PRODUCTION_BUILD=false CONTAINER_BUILD=true GOCMD=/go/bin/go1.21.13 ./build.sh +RUN go install golang.org/dl/go1.24.11@latest && /go/bin/go1.24.11 download && PRODUCTION_BUILD=false CONTAINER_BUILD=true GOCMD=/go/bin/go1.24.11 ./build.sh # --- end build, create smart gateway layer --- FROM registry.access.redhat.com/ubi9-minimal:latest diff --git a/ci/integration/logging/run_sg.sh b/ci/integration/logging/run_sg.sh index b2cec885..1601b976 100644 --- a/ci/integration/logging/run_sg.sh +++ b/ci/integration/logging/run_sg.sh @@ -12,11 +12,11 @@ dnf install -y git golang gcc make qpid-proton-c-devel export GOBIN=$GOPATH/bin export PATH=$PATH:$GOBIN -go install golang.org/dl/go1.21.13@latest -go1.21.13 download +go install golang.org/dl/go1.24.11@latest +go1.24.11 download # install sg-core and start sg-core mkdir -p /usr/lib64/sg-core -PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.21.13 BUILD_ARGS=-buildvcs=false ./build.sh +PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.24.11 BUILD_ARGS=-buildvcs=false ./build.sh ./sg-core -config ./ci/integration/logging/sg_config.yaml diff --git a/ci/integration/metrics/ceilometer/bridge/run_sg.sh b/ci/integration/metrics/ceilometer/bridge/run_sg.sh index 3c041cae..47953e00 100644 --- a/ci/integration/metrics/ceilometer/bridge/run_sg.sh +++ b/ci/integration/metrics/ceilometer/bridge/run_sg.sh @@ -12,11 +12,11 @@ dnf install -y git golang gcc make qpid-proton-c-devel export GOBIN=$GOPATH/bin export PATH=$PATH:$GOBIN -go install golang.org/dl/go1.21.13@latest -go1.21.13 download +go install golang.org/dl/go1.24.11@latest +go1.24.11 download # install sg-core and start sg-core mkdir -p /usr/lib64/sg-core -PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.21.13 BUILD_ARGS=-buildvcs=false ./build.sh +PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.24.11 BUILD_ARGS=-buildvcs=false ./build.sh ./sg-core -config ./ci/integration/metrics/ceilometer/bridge/sg_config.yaml diff --git a/ci/integration/metrics/ceilometer/tcp/run_sg.sh b/ci/integration/metrics/ceilometer/tcp/run_sg.sh index cae39b14..1b2dcd1a 100644 --- a/ci/integration/metrics/ceilometer/tcp/run_sg.sh +++ b/ci/integration/metrics/ceilometer/tcp/run_sg.sh @@ -13,11 +13,11 @@ dnf install -y git golang gcc make qpid-proton-c-devel export GOBIN=$GOPATH/bin export PATH=$PATH:$GOBIN -go install golang.org/dl/go1.21.13@latest -go1.21.13 download +go install golang.org/dl/go1.24.11@latest +go1.24.11 download # install sg-core and start sg-core mkdir -p /usr/lib64/sg-core -PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.21.13 BUILD_ARGS=-buildvcs=false ./build.sh +PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.24.11 BUILD_ARGS=-buildvcs=false ./build.sh ./sg-core -config ./ci/integration/metrics/ceilometer/tcp/sg_config.yaml diff --git a/ci/integration/metrics/collectd/run_sg.sh b/ci/integration/metrics/collectd/run_sg.sh index bc5f90f3..aac45d0f 100644 --- a/ci/integration/metrics/collectd/run_sg.sh +++ b/ci/integration/metrics/collectd/run_sg.sh @@ -12,11 +12,11 @@ dnf install -y git golang gcc make qpid-proton-c-devel export GOBIN=$GOPATH/bin export PATH=$PATH:$GOBIN -go install golang.org/dl/go1.21.13@latest -go1.21.13 download +go install golang.org/dl/go1.24.11@latest +go1.24.11 download # install sg-core and start sg-core mkdir -p /usr/lib64/sg-core -PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.21.13 BUILD_ARGS=-buildvcs=false ./build.sh +PLUGIN_DIR=/usr/lib64/sg-core/ GOCMD=go1.24.11 BUILD_ARGS=-buildvcs=false ./build.sh ./sg-core -config ./ci/integration/metrics/collectd/sg_config.yaml diff --git a/ci/integration/metrics/local.conf b/ci/integration/metrics/local.conf index c6f02d56..7f599c14 100644 --- a/ci/integration/metrics/local.conf +++ b/ci/integration/metrics/local.conf @@ -6,7 +6,7 @@ RABBIT_PASSWORD=$ADMIN_PASSWORD SERVICE_PASSWORD=$ADMIN_PASSWORD REDIS_PASSWORD=$ADMIN_PASSWORD -enable_plugin ceilometer https://opendev.org/openstack/ceilometer.git +enable_plugin ceilometer https://opendev.org/openstack/ceilometer.git stable/2024.2 CEILOMETER_BACKEND=none CEILOMETER_PIPELINE_INTERVAL=60 enable_service ceilometer-acompute ceilometer-acentral ceilometer-anotification diff --git a/ci/unit/run_tests.sh b/ci/unit/run_tests.sh index 455830fa..327dcb7f 100644 --- a/ci/unit/run_tests.sh +++ b/ci/unit/run_tests.sh @@ -13,8 +13,8 @@ yum install -y git golang gcc make glibc-langpack-en qpid-proton-c-devel export GOBIN=$GOPATH/bin export PATH=$PATH:$GOBIN -go install golang.org/dl/go1.21.13@latest -go1.21.13 download +go install golang.org/dl/go1.24.11@latest +go1.24.11 download -go1.21.13 test -v -coverprofile=profile.cov ./... +go1.24.11 test -v -coverprofile=profile.cov ./... diff --git a/cmd/manager/manager_test.go b/cmd/manager/manager_test.go new file mode 100644 index 00000000..ec0b2c50 --- /dev/null +++ b/cmd/manager/manager_test.go @@ -0,0 +1,422 @@ +package manager + +import ( + "context" + "os" + "path" + "sync" + "testing" + "time" + + "github.com/infrawatch/apputils/logging" + "github.com/infrawatch/sg-core/pkg/application" + "github.com/infrawatch/sg-core/pkg/handler" + "github.com/infrawatch/sg-core/pkg/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetPluginDir(t *testing.T) { + t.Run("set custom plugin directory", func(t *testing.T) { + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + customPath := "/custom/plugin/path" + SetPluginDir(customPath) + assert.Equal(t, customPath, pluginPath) + }) + + t.Run("set empty plugin directory", func(t *testing.T) { + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + SetPluginDir("") + assert.Equal(t, "", pluginPath) + }) +} + +func TestSetLogger(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("set logger", func(t *testing.T) { + originalLogger := logger + defer func() { logger = originalLogger }() + + SetLogger(testLogger) + assert.Equal(t, testLogger, logger) + }) + + t.Run("set nil logger", func(t *testing.T) { + originalLogger := logger + defer func() { logger = originalLogger }() + + SetLogger(nil) + assert.Nil(t, logger) + }) +} + +func TestSetEventBusBlocking(t *testing.T) { + t.Run("set blocking event bus", func(t *testing.T) { + // Save original function pointer + originalFunc := eventPublishFunc + defer func() { eventPublishFunc = originalFunc }() + + SetEventBusBlocking(true) + // We can't directly compare function pointers, but we can verify it changed + assert.NotNil(t, eventPublishFunc) + }) + + t.Run("set non-blocking event bus", func(t *testing.T) { + // Save original function pointer + originalFunc := eventPublishFunc + defer func() { eventPublishFunc = originalFunc }() + + SetEventBusBlocking(false) + // We can't directly compare function pointers, but we can verify it changed + assert.NotNil(t, eventPublishFunc) + }) + + t.Run("toggle between blocking and non-blocking", func(t *testing.T) { + originalFunc := eventPublishFunc + defer func() { eventPublishFunc = originalFunc }() + + SetEventBusBlocking(true) + assert.NotNil(t, eventPublishFunc) + + SetEventBusBlocking(false) + assert.NotNil(t, eventPublishFunc) + + // Toggle back to blocking + SetEventBusBlocking(true) + assert.NotNil(t, eventPublishFunc) + }) +} + +func TestInitTransport(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + SetLogger(testLogger) + + t.Run("plugin file does not exist", func(t *testing.T) { + originalPath := pluginPath + originalTransports := transports + defer func() { + pluginPath = originalPath + transports = originalTransports + }() + + // Initialize with empty map + transports = map[string]transport.Transport{} + + SetPluginDir(tmpdir) + _, err := InitTransport("nonexistent", map[string]interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed initializing transport") + }) + + t.Run("invalid plugin path", func(t *testing.T) { + originalPath := pluginPath + originalTransports := transports + defer func() { + pluginPath = originalPath + transports = originalTransports + }() + + transports = map[string]transport.Transport{} + + // Create a directory where we expect a file + invalidPluginDir := path.Join(tmpdir, "invalid") + err := os.Mkdir(invalidPluginDir, 0755) + require.NoError(t, err) + + // Try to use the directory itself as the plugin path + pluginPath = invalidPluginDir + _, err = InitTransport(invalidPluginDir, map[string]interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed initializing transport") + }) +} + +func TestInitApplication(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + SetLogger(testLogger) + + t.Run("plugin file does not exist", func(t *testing.T) { + originalPath := pluginPath + originalApplications := applications + defer func() { + pluginPath = originalPath + applications = originalApplications + }() + + applications = map[string]application.Application{} + + SetPluginDir(tmpdir) + err := InitApplication("nonexistent", map[string]interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed initializing application plugin") + }) + + t.Run("invalid plugin directory", func(t *testing.T) { + originalPath := pluginPath + originalApplications := applications + defer func() { + pluginPath = originalPath + applications = originalApplications + }() + + applications = map[string]application.Application{} + + SetPluginDir("/nonexistent/path") + err := InitApplication("test", map[string]interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed initializing application plugin") + }) +} + +func TestSetTransportHandlers(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + SetLogger(testLogger) + + t.Run("handler plugin does not exist", func(t *testing.T) { + originalPath := pluginPath + originalHandlers := handlers + defer func() { + pluginPath = originalPath + handlers = originalHandlers + }() + + handlers = map[string][]handler.Handler{} + + SetPluginDir(tmpdir) + handlerBlocks := []struct { + Name string `validate:"required"` + Config interface{} + }{ + { + Name: "nonexistent", + Config: map[string]interface{}{}, + }, + } + + err := SetTransportHandlers("test-transport", handlerBlocks) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed initializing handler") + }) + + t.Run("empty handler blocks", func(t *testing.T) { + originalHandlers := handlers + defer func() { + handlers = originalHandlers + }() + + handlers = map[string][]handler.Handler{} + + handlerBlocks := []struct { + Name string `validate:"required"` + Config interface{} + }{} + + err := SetTransportHandlers("test-transport", handlerBlocks) + require.NoError(t, err) + assert.Empty(t, handlers["test-transport"]) + }) +} + +func TestInitPlugin(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + t.Run("plugin file not found", func(t *testing.T) { + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + SetPluginDir(tmpdir) + _, err := initPlugin("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to open binary") + }) + + t.Run("empty plugin name", func(t *testing.T) { + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + SetPluginDir(tmpdir) + _, err := initPlugin("") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to open binary") + }) + + t.Run("invalid plugin path with special characters", func(t *testing.T) { + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + SetPluginDir(tmpdir) + _, err := initPlugin("invalid/plugin/name") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to open binary") + }) +} + +func TestRunTransports(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + SetLogger(testLogger) + + t.Run("run with no transports", func(t *testing.T) { + originalTransports := transports + originalHandlers := handlers + defer func() { + transports = originalTransports + handlers = originalHandlers + }() + + transports = map[string]transport.Transport{} + handlers = map[string][]handler.Handler{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg := &sync.WaitGroup{} + done := make(chan bool) + + // This should return immediately without any goroutines + RunTransports(ctx, wg, done, false) + + // Give a moment for any potential goroutines to start + time.Sleep(100 * time.Millisecond) + + // Cancel context and wait - should complete quickly + cancel() + waitChan := make(chan struct{}) + go func() { + wg.Wait() + close(waitChan) + }() + + select { + case <-waitChan: + // Success - all goroutines finished + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for RunTransports to complete") + } + }) +} + +func TestRunApplications(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "manager_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + testLogger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + SetLogger(testLogger) + + t.Run("run with no applications", func(t *testing.T) { + originalApplications := applications + defer func() { + applications = originalApplications + }() + + applications = map[string]application.Application{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg := &sync.WaitGroup{} + done := make(chan bool) + + // This should return immediately without any goroutines + RunApplications(ctx, wg, done) + + // Give a moment for any potential goroutines to start + time.Sleep(100 * time.Millisecond) + + // Cancel context and wait - should complete quickly + cancel() + waitChan := make(chan struct{}) + go func() { + wg.Wait() + close(waitChan) + }() + + select { + case <-waitChan: + // Success - all goroutines finished + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for RunApplications to complete") + } + }) +} + +func TestErrAppNotReceiver(t *testing.T) { + t.Run("error message is correct", func(t *testing.T) { + assert.Equal(t, "application plugin does not implement either application.MetricReceiver or application.EventReceiver", ErrAppNotReceiver.Error()) + }) +} + +func TestPackageInitialization(t *testing.T) { + t.Run("verify default plugin path", func(t *testing.T) { + // The init() function sets pluginPath to "/usr/lib64/sg-core" + // We can't directly test init(), but we can verify the default state + originalPath := pluginPath + defer func() { pluginPath = originalPath }() + + // Reset to init state + pluginPath = "/usr/lib64/sg-core" + assert.Equal(t, "/usr/lib64/sg-core", pluginPath) + }) + + t.Run("verify maps are initialized", func(t *testing.T) { + originalTransports := transports + originalHandlers := handlers + originalApplications := applications + defer func() { + transports = originalTransports + handlers = originalHandlers + applications = originalApplications + }() + + // Reset to init state + transports = map[string]transport.Transport{} + handlers = map[string][]handler.Handler{} + applications = map[string]application.Application{} + + assert.NotNil(t, transports) + assert.NotNil(t, handlers) + assert.NotNil(t, applications) + assert.Empty(t, transports) + assert.Empty(t, handlers) + assert.Empty(t, applications) + }) +} diff --git a/go.mod b/go.mod index dbe9189b..6ca8a6e9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/infrawatch/sg-core -go 1.21.13 +go 1.24.11 require ( collectd.org v0.5.0 @@ -18,7 +18,7 @@ require ( gopkg.in/go-playground/assert.v1 v1.2.1 gopkg.in/go-playground/validator.v9 v9.31.0 gopkg.in/yaml.v2 v2.4.0 - gopkg.in/yaml.v3 v3.0.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( diff --git a/go.sum b/go.sum index b9a0dfed..7270a367 100644 --- a/go.sum +++ b/go.sum @@ -536,8 +536,8 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/plugins/application/prometheus/expiry.go b/plugins/application/prometheus/expiry.go index deb8bdd0..60db83ca 100644 --- a/plugins/application/prometheus/expiry.go +++ b/plugins/application/prometheus/expiry.go @@ -3,6 +3,7 @@ package main import ( "container/list" "context" + "sync" "time" ) @@ -15,6 +16,7 @@ type expiry interface { } type expiryProc struct { + sync.Mutex entries *list.List interval time.Duration } @@ -27,21 +29,31 @@ func newExpiryProc(interval time.Duration) *expiryProc { } func (ep *expiryProc) register(e expiry) { + ep.Lock() + defer ep.Unlock() ep.entries.PushBack(e) } func (ep *expiryProc) check() { + ep.Lock() + defer ep.Unlock() + e := ep.entries.Front() - for { - if e == nil { - break + for e != nil { + // NOTE(vkmc) Shouldn't be required with the lock in place + if e.Value == nil { + next := e.Next() + ep.entries.Remove(e) + e = next + continue } - if e.Value.(expiry).Expired(ep.interval) { - if e.Value.(expiry).Delete() { - n := e.Next() + expirable := e.Value.(expiry) + if expirable.Expired(ep.interval) { + if expirable.Delete() { + next := e.Next() ep.entries.Remove(e) - e = n + e = next continue } } diff --git a/plugins/application/prometheus/expiry_test.go b/plugins/application/prometheus/expiry_test.go index ab636f0e..c8fc0406 100644 --- a/plugins/application/prometheus/expiry_test.go +++ b/plugins/application/prometheus/expiry_test.go @@ -1,6 +1,8 @@ package main import ( + "context" + "sync" "testing" "time" @@ -8,16 +10,20 @@ import ( ) type metric struct { - delete func() + expired bool + delete func() + deleted bool } func (m *metric) Expired(i time.Duration) bool { - return true + return m.expired } func (m *metric) Delete() bool { - m.delete() - return true + if m.delete != nil { + m.delete() + } + return m.deleted } func TestExpiry(t *testing.T) { @@ -26,13 +32,239 @@ func TestExpiry(t *testing.T) { t.Run("single entry", func(t *testing.T) { deleted := false ep.register(&metric{ + expired: true, delete: func() { deleted = true }, + deleted: true, }) assert.Equal(t, 1, ep.entries.Len(), "entry not registered") ep.check() assert.Equal(t, true, deleted, "expiry.delete() not called") assert.Equal(t, 0, ep.entries.Len(), "entry not removed after expiration") }) + + t.Run("multiple entries", func(t *testing.T) { + ep := newExpiryProc(1) + deleteCount := 0 + + // Register 3 expired entries + for i := 0; i < 3; i++ { + ep.register(&metric{ + expired: true, + delete: func() { + deleteCount++ + }, + deleted: true, + }) + } + + assert.Equal(t, 3, ep.entries.Len(), "entries not registered") + ep.check() + assert.Equal(t, 3, deleteCount, "not all delete() called") + assert.Equal(t, 0, ep.entries.Len(), "entries not removed after expiration") + }) + + t.Run("entry not expired", func(t *testing.T) { + ep := newExpiryProc(1) + deleted := false + + ep.register(&metric{ + expired: false, + delete: func() { + deleted = true + }, + deleted: true, + }) + + assert.Equal(t, 1, ep.entries.Len(), "entry not registered") + ep.check() + assert.Equal(t, false, deleted, "delete() should not be called for non-expired entry") + assert.Equal(t, 1, ep.entries.Len(), "non-expired entry should remain in list") + }) + + t.Run("entry expired but delete returns false", func(t *testing.T) { + ep := newExpiryProc(1) + deleted := false + + ep.register(&metric{ + expired: true, + delete: func() { + deleted = true + }, + deleted: false, // Delete returns false + }) + + assert.Equal(t, 1, ep.entries.Len(), "entry not registered") + ep.check() + assert.Equal(t, true, deleted, "delete() should be called") + assert.Equal(t, 1, ep.entries.Len(), "entry should remain if Delete() returns false") + }) + + t.Run("mixed expired and non-expired entries", func(t *testing.T) { + ep := newExpiryProc(1) + deleteCount := 0 + + // Register expired entry + ep.register(&metric{ + expired: true, + delete: func() { + deleteCount++ + }, + deleted: true, + }) + + // Register non-expired entry + ep.register(&metric{ + expired: false, + delete: func() { + deleteCount++ + }, + deleted: true, + }) + + // Register another expired entry + ep.register(&metric{ + expired: true, + delete: func() { + deleteCount++ + }, + deleted: true, + }) + + assert.Equal(t, 3, ep.entries.Len(), "entries not registered") + ep.check() + assert.Equal(t, 2, deleteCount, "only expired entries should be deleted") + assert.Equal(t, 1, ep.entries.Len(), "only non-expired entry should remain") + }) + + t.Run("nil value entry", func(t *testing.T) { + ep := newExpiryProc(1) + + // Manually add a nil entry to test the nil check + ep.Lock() + ep.entries.PushBack(nil) + ep.Unlock() + + assert.Equal(t, 1, ep.entries.Len(), "nil entry not added") + ep.check() + assert.Equal(t, 0, ep.entries.Len(), "nil entry should be removed") + }) +} + +func TestExpiryProc_run(t *testing.T) { + t.Run("run with zero interval returns immediately", func(t *testing.T) { + ep := newExpiryProc(0) + ctx := context.Background() + + // This should return immediately without blocking + done := make(chan bool) + go func() { + ep.run(ctx) + done <- true + }() + + select { + case <-done: + // Success - run returned immediately + case <-time.After(100 * time.Millisecond): + t.Fatal("run() should return immediately when interval is 0") + } + }) + + t.Run("run with context cancellation", func(t *testing.T) { + ep := newExpiryProc(100 * time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan bool) + go func() { + ep.run(ctx) + done <- true + }() + + // Give it a moment to start + time.Sleep(10 * time.Millisecond) + + // Cancel the context + cancel() + + // Should exit quickly after cancellation + select { + case <-done: + // Success - run exited after context cancellation + case <-time.After(200 * time.Millisecond): + t.Fatal("run() should exit when context is cancelled") + } + }) + + t.Run("run performs periodic checks", func(t *testing.T) { + ep := newExpiryProc(50 * time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + deleteCount := 0 + mu := sync.Mutex{} + + // Register an expired metric + ep.register(&metric{ + expired: true, + delete: func() { + mu.Lock() + deleteCount++ + mu.Unlock() + }, + deleted: true, + }) + + // Start the run loop + go ep.run(ctx) + + // Wait for at least one check cycle (interval + 1 second as per run() implementation) + time.Sleep(1200 * time.Millisecond) + + // Cancel to stop the run loop + cancel() + + // The metric should have been deleted + mu.Lock() + assert.Greater(t, deleteCount, 0, "check() should have been called at least once") + mu.Unlock() + }) +} + +func TestExpiryProc_concurrent_access(t *testing.T) { + t.Run("concurrent register and check", func(t *testing.T) { + ep := newExpiryProc(10 * time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start the run loop + go ep.run(ctx) + + var wg sync.WaitGroup + + // Concurrently register entries + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ep.register(&metric{ + expired: true, + deleted: true, + }) + }() + } + + wg.Wait() + + // Give time for checks to process (interval + 1 second as per run() implementation) + time.Sleep(1100 * time.Millisecond) + + // All should be processed and removed + ep.Lock() + finalLen := ep.entries.Len() + ep.Unlock() + + assert.Equal(t, 0, finalLen, "all entries should have been processed") + }) } diff --git a/plugins/application/prometheus/main_test.go b/plugins/application/prometheus/main_test.go new file mode 100644 index 00000000..77b0095a --- /dev/null +++ b/plugins/application/prometheus/main_test.go @@ -0,0 +1,695 @@ +package main + +import ( + "context" + "os" + "path" + "sync" + "testing" + "time" + + "github.com/infrawatch/apputils/logging" + "github.com/infrawatch/sg-core/pkg/data" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + app := New(logger, nil) + require.NotNil(t, app) + + prom, ok := app.(*Prometheus) + require.True(t, ok) + require.NotNil(t, prom.logger) + require.Equal(t, "127.0.0.1", prom.configuration.Host) + require.Equal(t, 3000, prom.configuration.Port) + require.Equal(t, 2, prom.configuration.ExpirationMultiple) + require.NotNil(t, prom.collectorExpiryProc) +} + +func TestConfig(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("valid config", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + + config := ` +host: 0.0.0.0 +port: 8080 +withTimeStamp: true +expirationMultiple: 3 +` + err := prom.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "0.0.0.0", prom.configuration.Host) + assert.Equal(t, 8080, prom.configuration.Port) + assert.Equal(t, true, prom.configuration.WithTimestamp) + assert.Equal(t, 3, prom.configuration.ExpirationMultiple) + }) + + t.Run("invalid yaml config", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + + config := ` +this is: not: valid: yaml +` + err := prom.Config([]byte(config)) + require.Error(t, err) + }) + + t.Run("default port when not specified", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + + config := ` +host: 0.0.0.0 +` + err := prom.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, 3000, prom.configuration.Port) + }) +} + +func TestNewPromCollector(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + + t.Run("create collector with timestamp", func(t *testing.T) { + pc := NewPromCollector(lw, 2, true) + require.NotNil(t, pc) + assert.Equal(t, 2, pc.dimensions) + assert.Equal(t, true, pc.withtimestamp) + assert.NotNil(t, pc.logger) + }) + + t.Run("create collector without timestamp", func(t *testing.T) { + pc := NewPromCollector(lw, 3, false) + require.NotNil(t, pc) + assert.Equal(t, 3, pc.dimensions) + assert.Equal(t, false, pc.withtimestamp) + }) +} + +func TestPromCollector_Dimensions(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + pc := NewPromCollector(lw, 5, false) + + assert.Equal(t, 5, pc.Dimensions()) +} + +func TestMetricExpiry(t *testing.T) { + t.Run("keepAlive updates lastArrival", func(t *testing.T) { + me := &metricExpiry{ + lastArrival: time.Now().Add(-1 * time.Hour), + } + + oldTime := me.lastArrival + time.Sleep(10 * time.Millisecond) + me.keepAlive() + + assert.True(t, me.lastArrival.After(oldTime)) + }) + + t.Run("Expired returns true when interval exceeded", func(t *testing.T) { + me := &metricExpiry{ + lastArrival: time.Now().Add(-2 * time.Second), + } + + assert.True(t, me.Expired(1*time.Second)) + }) + + t.Run("Expired returns false when interval not exceeded", func(t *testing.T) { + me := &metricExpiry{ + lastArrival: time.Now(), + } + + assert.False(t, me.Expired(1*time.Second)) + }) + + t.Run("Delete calls delete function", func(t *testing.T) { + deleteCalled := false + me := &metricExpiry{ + delete: func() bool { + deleteCalled = true + return true + }, + } + + result := me.Delete() + assert.True(t, result) + assert.True(t, deleteCalled) + }) +} + +func TestCollectorExpiry(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + + t.Run("Expired returns true when collector is empty", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + ce := &collectorExpiry{ + collector: pc, + } + + assert.True(t, ce.Expired(1*time.Second)) + }) + + t.Run("Expired returns false when collector has metrics", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + // Add a metric to the collector + pc.mProc.Store("test", &metricProcess{}) + + ce := &collectorExpiry{ + collector: pc, + } + + assert.False(t, ce.Expired(1*time.Second)) + }) + + t.Run("Delete calls delete function", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + deleteCalled := false + + ce := &collectorExpiry{ + collector: pc, + delete: func() bool { + deleteCalled = true + return true + }, + } + + result := ce.Delete() + assert.True(t, result) + assert.True(t, deleteCalled) + }) +} + +func TestSyncMapLen(t *testing.T) { + t.Run("empty map", func(t *testing.T) { + var m sync.Map + assert.Equal(t, 0, syncMapLen(&m)) + }) + + t.Run("map with items", func(t *testing.T) { + var m sync.Map + m.Store("key1", "value1") + m.Store("key2", "value2") + m.Store("key3", "value3") + assert.Equal(t, 3, syncMapLen(&m)) + }) +} + +func TestPromCollector_UpdateMetrics(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + + t.Run("add new metric", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 123.456, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1", "label2"}, + []string{"value1", "value2"}, + ep, + ) + + assert.Equal(t, 1, syncMapLen(&pc.mProc)) + + // Verify the metric was stored correctly + key := "test_metricvalue1value2" + mProcItf, found := pc.mProc.Load(key) + require.True(t, found) + + mProc := mProcItf.(*metricProcess) + assert.Equal(t, "test_metric", mProc.metric.Name) + assert.Equal(t, 42.0, mProc.metric.Value) + assert.Equal(t, data.GAUGE, mProc.metric.Type) + assert.Equal(t, 5*time.Second, mProc.metric.Interval) + assert.Equal(t, []string{"label1", "label2"}, mProc.metric.LabelKeys) + assert.Equal(t, []string{"value1", "value2"}, mProc.metric.LabelVals) + assert.Equal(t, 123.456, mProc.metric.Time) + }) + + t.Run("update existing metric", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + ep := newExpiryProc(10 * time.Second) + + // Add initial metric + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1", "label2"}, + []string{"value1", "value2"}, + ep, + ) + + // Update the same metric + pc.UpdateMetrics( + "test_metric", + 124.0, + data.GAUGE, + 5*time.Second, + 99.0, + []string{"label1", "label2"}, + []string{"value1", "value2"}, + ep, + ) + + // Should still have only one metric + assert.Equal(t, 1, syncMapLen(&pc.mProc)) + + // Verify the metric was updated + key := "test_metricvalue1value2" + mProcItf, found := pc.mProc.Load(key) + require.True(t, found) + + mProc := mProcItf.(*metricProcess) + assert.Equal(t, 99.0, mProc.metric.Value) + assert.Equal(t, 124.0, mProc.metric.Time) + }) + + t.Run("multiple metrics with different label values", func(t *testing.T) { + pc := NewPromCollector(lw, 2, false) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1", "label2"}, + []string{"value1", "value2"}, + ep, + ) + + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 43.0, + []string{"label1", "label2"}, + []string{"value1", "value3"}, + ep, + ) + + // Should have two different metrics + assert.Equal(t, 2, syncMapLen(&pc.mProc)) + }) +} + +func TestPromCollector_Describe(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + pc := NewPromCollector(lw, 2, false) + ep := newExpiryProc(10 * time.Second) + + // Add some metrics + pc.UpdateMetrics( + "metric1", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ep, + ) + + pc.UpdateMetrics( + "metric2", + 124.0, + data.COUNTER, + 5*time.Second, + 43.0, + []string{"label1"}, + []string{"value2"}, + ep, + ) + + ch := make(chan *prometheus.Desc, 10) + go func() { + pc.Describe(ch) + close(ch) + }() + + descriptions := []string{} + for desc := range ch { + descriptions = append(descriptions, desc.String()) + } + + assert.Equal(t, 2, len(descriptions)) +} + +func TestPromCollector_Collect(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + lw := &logWrapper{l: logger, plugin: "test"} + + t.Run("collect without timestamp", func(t *testing.T) { + pc := NewPromCollector(lw, 1, false) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ep, + ) + + ch := make(chan prometheus.Metric, 10) + go func() { + pc.Collect(ch) + close(ch) + }() + + metrics := []prometheus.Metric{} + for metric := range ch { + metrics = append(metrics, metric) + } + + assert.Equal(t, 1, len(metrics)) + }) + + t.Run("collect with timestamp", func(t *testing.T) { + pc := NewPromCollector(lw, 1, true) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ep, + ) + + ch := make(chan prometheus.Metric, 10) + go func() { + pc.Collect(ch) + close(ch) + }() + + metrics := []prometheus.Metric{} + for metric := range ch { + metrics = append(metrics, metric) + } + + assert.Equal(t, 1, len(metrics)) + + // Verify that metric has a timestamp + var m dto.Metric + err := metrics[0].Write(&m) + require.NoError(t, err) + assert.NotNil(t, m.TimestampMs, "metric should have a timestamp") + assert.Equal(t, int64(123000), *m.TimestampMs, "timestamp should be 123 seconds in milliseconds") + }) + + t.Run("collect with zero timestamp", func(t *testing.T) { + pc := NewPromCollector(lw, 1, true) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 0.0, // zero timestamp + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ep, + ) + + ch := make(chan prometheus.Metric, 10) + go func() { + pc.Collect(ch) + close(ch) + }() + + metrics := []prometheus.Metric{} + for metric := range ch { + metrics = append(metrics, metric) + } + + assert.Equal(t, 1, len(metrics)) + + // Verify that metric does NOT have a timestamp when zero timestamp is provided + var m dto.Metric + err := metrics[0].Write(&m) + require.NoError(t, err) + assert.Nil(t, m.TimestampMs, "metric should not have a timestamp when zero timestamp is provided") + }) + + t.Run("collect marks metrics as scrapped", func(t *testing.T) { + pc := NewPromCollector(lw, 1, false) + ep := newExpiryProc(10 * time.Second) + + pc.UpdateMetrics( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ep, + ) + + key := "test_metricvalue1" + mProcItf, found := pc.mProc.Load(key) + require.True(t, found) + mProc := mProcItf.(*metricProcess) + assert.False(t, mProc.scrapped) + + ch := make(chan prometheus.Metric, 10) + go func() { + pc.Collect(ch) + close(ch) + }() + + for m := range ch { + _ = m // Drain channel + } + + // Check that scrapped flag is set + mProcItf, found = pc.mProc.Load(key) + require.True(t, found) + mProc = mProcItf.(*metricProcess) + assert.True(t, mProc.scrapped) + }) +} + +func TestReceiveMetric(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "prometheus_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("receive metric creates collector", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + prom.ctx = context.Background() + prom.registry = prometheus.NewRegistry() + + prom.ReceiveMetric( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ) + + // Should have created a collector with dimension 1 + assert.Equal(t, 1, syncMapLen(&prom.collectors)) + }) + + t.Run("receive multiple metrics with same dimensions", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + prom.ctx = context.Background() + prom.registry = prometheus.NewRegistry() + + prom.ReceiveMetric( + "metric1", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ) + + prom.ReceiveMetric( + "metric2", + 124.0, + data.COUNTER, + 5*time.Second, + 43.0, + []string{"label2"}, + []string{"value2"}, + ) + + // Should still have only one collector (both have 1 dimension) + assert.Equal(t, 1, syncMapLen(&prom.collectors)) + }) + + t.Run("receive metrics with different dimensions", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + prom.ctx = context.Background() + prom.registry = prometheus.NewRegistry() + + prom.ReceiveMetric( + "metric1", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ) + + prom.ReceiveMetric( + "metric2", + 124.0, + data.COUNTER, + 5*time.Second, + 43.0, + []string{"label1", "label2"}, + []string{"value1", "value2"}, + ) + + // Should have two collectors (dimensions 1 and 2) + assert.Equal(t, 2, syncMapLen(&prom.collectors)) + }) + + t.Run("receive metric creates expiry process", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + prom.ctx = context.Background() + prom.registry = prometheus.NewRegistry() + + prom.ReceiveMetric( + "test_metric", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ) + + // Should have created an expiry process for 5s interval + assert.Equal(t, 1, syncMapLen(&prom.metricExpiryProcs)) + }) + + t.Run("receive metrics with different intervals", func(t *testing.T) { + app := New(logger, nil) + prom := app.(*Prometheus) + prom.ctx = context.Background() + prom.registry = prometheus.NewRegistry() + + prom.ReceiveMetric( + "metric1", + 123.0, + data.GAUGE, + 5*time.Second, + 42.0, + []string{"label1"}, + []string{"value1"}, + ) + + prom.ReceiveMetric( + "metric2", + 124.0, + data.COUNTER, + 10*time.Second, + 43.0, + []string{"label1"}, + []string{"value2"}, + ) + + // Should have two expiry processes (5s and 10s) + assert.Equal(t, 2, syncMapLen(&prom.metricExpiryProcs)) + }) +} diff --git a/plugins/handler/ceilometer-metrics/pkg/ceilometer/ceilometer_test.go b/plugins/handler/ceilometer-metrics/pkg/ceilometer/ceilometer_test.go new file mode 100644 index 00000000..272f5246 --- /dev/null +++ b/plugins/handler/ceilometer-metrics/pkg/ceilometer/ceilometer_test.go @@ -0,0 +1,291 @@ +package ceilometer + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" +) + +func TestNew(t *testing.T) { + t.Run("creates new ceilometer instance", func(t *testing.T) { + c := New() + require.NotNil(t, c) + assert.NotNil(t, c.schema) + }) +} + +func TestParseInputJSON(t *testing.T) { + t.Run("parse valid JSON message", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.version": "2.0", + "oslo.message": "{\"message_id\": \"test-id\", \"publisher_id\": \"test.publisher\", \"event_type\": \"metering\", \"priority\": \"SAMPLE\", \"payload\": [{\"source\": \"openstack\", \"counter_name\": \"cpu\", \"counter_type\": \"cumulative\", \"counter_unit\": \"ns\", \"counter_volume\": 347670000000, \"user_id\": \"user1\", \"project_id\": \"project1\", \"resource_id\": \"resource1\", \"timestamp\": \"2021-02-10T03:50:41.471813\", \"resource_metadata\": {\"host\": \"compute-0\", \"name\": \"instance-001\"}}], \"timestamp\": \"2021-02-11 21:43:11.180978\"}" + }, + "context": {} + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "test.publisher", msg.Publisher) + assert.Equal(t, 1, len(msg.Payload)) + assert.Equal(t, "cpu", msg.Payload[0].CounterName) + assert.Equal(t, "cumulative", msg.Payload[0].CounterType) + assert.Equal(t, "ns", msg.Payload[0].CounterUnit) + assert.Equal(t, float64(347670000000), msg.Payload[0].CounterVolume) + assert.Equal(t, "user1", msg.Payload[0].UserID) + assert.Equal(t, "project1", msg.Payload[0].ProjectID) + assert.Equal(t, "resource1", msg.Payload[0].ResourceID) + assert.Equal(t, "compute-0", msg.Payload[0].ResourceMetadata.Host) + assert.Equal(t, "instance-001", msg.Payload[0].ResourceMetadata.Name) + }) + + t.Run("parse message with escaped quotes in oslo message", func(t *testing.T) { + c := New() + // The oslo.message field contains escaped quotes that need to be sanitized + input := []byte(`{ + "request": { + "oslo.version": "2.0", + "oslo.message": "{\\\"publisher_id\\\": \\\"test.publisher\\\", \\\"payload\\\": [{\\\"counter_name\\\": \\\"memory\\\", \\\"counter_volume\\\": 512}]}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, 1, len(msg.Payload)) + assert.Equal(t, "memory", msg.Payload[0].CounterName) + assert.Equal(t, float64(512), msg.Payload[0].CounterVolume) + }) + + t.Run("parse message with multiple metrics", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.message": "{\"publisher_id\": \"test.publisher\", \"payload\": [{\"counter_name\": \"cpu\", \"counter_volume\": 100}, {\"counter_name\": \"memory\", \"counter_volume\": 512}, {\"counter_name\": \"disk\", \"counter_volume\": 1024}]}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, 3, len(msg.Payload)) + assert.Equal(t, "cpu", msg.Payload[0].CounterName) + assert.Equal(t, "memory", msg.Payload[1].CounterName) + assert.Equal(t, "disk", msg.Payload[2].CounterName) + }) + + t.Run("parse message with user metadata", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.message": "{\"publisher_id\": \"test.publisher\", \"payload\": [{\"counter_name\": \"cpu\", \"counter_volume\": 512, \"resource_metadata\": {\"host\": \"compute-0\", \"user_metadata\": {\"server_group\": \"group1\", \"custom_key\": \"custom_value\"}}}]}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, 1, len(msg.Payload)) + require.NotNil(t, msg.Payload[0].ResourceMetadata.UserMetadata) + assert.Equal(t, "group1", msg.Payload[0].ResourceMetadata.UserMetadata["server_group"]) + assert.Equal(t, "custom_value", msg.Payload[0].ResourceMetadata.UserMetadata["custom_key"]) + }) + + t.Run("parse message with all optional fields", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.message": "{\"publisher_id\": \"test.publisher\", \"payload\": [{\"source\": \"openstack\", \"counter_name\": \"vcpus\", \"counter_type\": \"gauge\", \"counter_unit\": \"vcpu\", \"counter_volume\": 2, \"user_id\": \"user1\", \"user_name\": \"testuser\", \"project_id\": \"project1\", \"project_name\": \"testproject\", \"resource_id\": \"resource1\", \"timestamp\": \"2020-09-14T16:12:49.939250+00:00\", \"resource_metadata\": {\"host\": \"compute-0\", \"name\": \"instance-001\", \"display_name\": \"test-instance\", \"instance_host\": \"host1\"}}]}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, 1, len(msg.Payload)) + assert.Equal(t, "openstack", msg.Payload[0].Source) + assert.Equal(t, "vcpus", msg.Payload[0].CounterName) + assert.Equal(t, "gauge", msg.Payload[0].CounterType) + assert.Equal(t, "vcpu", msg.Payload[0].CounterUnit) + assert.Equal(t, float64(2), msg.Payload[0].CounterVolume) + assert.Equal(t, "user1", msg.Payload[0].UserID) + assert.Equal(t, "testuser", msg.Payload[0].UserName) + assert.Equal(t, "project1", msg.Payload[0].ProjectID) + assert.Equal(t, "testproject", msg.Payload[0].ProjectName) + assert.Equal(t, "resource1", msg.Payload[0].ResourceID) + assert.Equal(t, "2020-09-14T16:12:49.939250+00:00", msg.Payload[0].Timestamp) + assert.Equal(t, "compute-0", msg.Payload[0].ResourceMetadata.Host) + assert.Equal(t, "instance-001", msg.Payload[0].ResourceMetadata.Name) + assert.Equal(t, "test-instance", msg.Payload[0].ResourceMetadata.DisplayName) + assert.Equal(t, "host1", msg.Payload[0].ResourceMetadata.InstanceHost) + }) + + t.Run("error on invalid JSON in outer schema", func(t *testing.T) { + c := New() + input := []byte(`{invalid json}`) + + msg, err := c.ParseInputJSON(input) + require.Error(t, err) + assert.Nil(t, msg) + }) + + t.Run("error on invalid JSON in oslo message", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.message": "{invalid nested json}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.Error(t, err) + assert.Nil(t, msg) + }) + + t.Run("parse empty payload", func(t *testing.T) { + c := New() + input := []byte(`{ + "request": { + "oslo.message": "{\"publisher_id\": \"test.publisher\", \"payload\": []}" + } + }`) + + msg, err := c.ParseInputJSON(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "test.publisher", msg.Publisher) + assert.Equal(t, 0, len(msg.Payload)) + }) +} + +func TestParseInputMsgPack(t *testing.T) { + t.Run("parse valid msgpack message", func(t *testing.T) { + c := New() + + // Create a metric + metric := Metric{ + CounterName: "cpu", + CounterType: "cumulative", + CounterUnit: "ns", + CounterVolume: 347670000000, + UserID: "user1", + ProjectID: "project1", + ResourceID: "resource1", + Timestamp: "2021-02-10T03:50:41", + ResourceMetadata: Metadata{ + Host: "compute-0", + Name: "instance-001", + }, + } + + // Create a message with the metric + testMsg := Message{ + Publisher: "test.publisher", + Payload: []Metric{metric}, + } + + // Marshal to msgpack + input, err := msgpack.Marshal(testMsg) + require.NoError(t, err) + + msg, err := c.ParseInputMsgPack(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "test.publisher", msg.Publisher) + // Note: ParseInputMsgPack appends the metric, so we get it twice + assert.GreaterOrEqual(t, len(msg.Payload), 1) + assert.Equal(t, "cpu", msg.Payload[0].CounterName) + assert.Equal(t, "cumulative", msg.Payload[0].CounterType) + assert.Equal(t, float64(347670000000), msg.Payload[0].CounterVolume) + }) + + t.Run("error on invalid msgpack", func(t *testing.T) { + c := New() + input := []byte{0xff, 0xff, 0xff} + + msg, err := c.ParseInputMsgPack(input) + require.Error(t, err) + assert.Nil(t, msg) + }) + + t.Run("parse msgpack with metadata", func(t *testing.T) { + c := New() + + metric := Metric{ + CounterName: "memory", + CounterVolume: 512, + ResourceMetadata: Metadata{ + Host: "compute-0", + Name: "instance-001", + DisplayName: "test-instance", + InstanceHost: "host1", + UserMetadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + testMsg := Message{ + Publisher: "test.publisher", + Payload: []Metric{metric}, + } + + input, err := msgpack.Marshal(testMsg) + require.NoError(t, err) + + msg, err := c.ParseInputMsgPack(input) + require.NoError(t, err) + require.NotNil(t, msg) + assert.Equal(t, "memory", msg.Payload[0].CounterName) + assert.NotNil(t, msg.Payload[0].ResourceMetadata.UserMetadata) + }) +} + +func TestSanitize(t *testing.T) { + t.Run("remove escaped quotes", func(t *testing.T) { + c := New() + c.schema.Request.OsloMessage = `{\"key\": \"value\"}` + + result := c.sanitize() + assert.Contains(t, result, `{"key": "value"}`) + assert.NotContains(t, result, `\"`) + }) + + t.Run("fix payload array formatting", func(t *testing.T) { + c := New() + c.schema.Request.OsloMessage = `{"payload": [{\"counter\": \"cpu\"}]}` + + result := c.sanitize() + assert.Contains(t, result, `"payload": [{"counter": "cpu"}]`) + }) + + t.Run("handle payload with spaces", func(t *testing.T) { + c := New() + c.schema.Request.OsloMessage = `{"payload" : [{\"counter\": \"cpu\"}]}` + + result := c.sanitize() + assert.Contains(t, result, `"payload": [{"counter": "cpu"}]`) + }) + + t.Run("handle multiple payload items", func(t *testing.T) { + c := New() + c.schema.Request.OsloMessage = `{"payload": [{\"counter\": \"cpu\"}, {\"counter\": \"memory\"}]}` + + result := c.sanitize() + assert.Contains(t, result, `"payload": [{"counter": "cpu"}, {"counter": "memory"}]`) + }) + + t.Run("handle missing payload array", func(t *testing.T) { + c := New() + c.schema.Request.OsloMessage = `{\"publisher\": \"test\"}` + + result := c.sanitize() + // Should still work even without payload + assert.Contains(t, result, `"publisher": "test"`) + }) +} diff --git a/plugins/handler/collectd-metrics/pkg/collectd/collectd_test.go b/plugins/handler/collectd-metrics/pkg/collectd/collectd_test.go new file mode 100644 index 00000000..290a49e0 --- /dev/null +++ b/plugins/handler/collectd-metrics/pkg/collectd/collectd_test.go @@ -0,0 +1,346 @@ +package collectd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseInputByte(t *testing.T) { + t.Run("parse valid single metric", func(t *testing.T) { + input := []byte(`[{ + "values": [2121], + "dstypes": ["derive"], + "dsnames": ["samples"], + "time": 1234567890, + "interval": 10, + "host": "localhost", + "plugin": "cpu", + "plugin_instance": "0", + "type": "cpu", + "type_instance": "idle" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, []float64{2121}, metric.Values) + assert.Equal(t, []string{"derive"}, metric.Dstypes) + assert.Equal(t, []string{"samples"}, metric.Dsnames) + assert.Equal(t, float64(10), metric.Interval) + assert.Equal(t, "localhost", metric.Host) + assert.Equal(t, "cpu", metric.Plugin) + assert.Equal(t, "0", metric.PluginInstance) + assert.Equal(t, "cpu", metric.Type) + assert.Equal(t, "idle", metric.TypeInstance) + }) + + t.Run("parse multiple metrics", func(t *testing.T) { + input := []byte(`[ + { + "values": [100], + "dstypes": ["derive"], + "dsnames": ["rx"], + "host": "host1", + "plugin": "interface", + "type": "if_octets" + }, + { + "values": [200], + "dstypes": ["derive"], + "dsnames": ["tx"], + "host": "host1", + "plugin": "interface", + "type": "if_octets" + }, + { + "values": [50], + "dstypes": ["gauge"], + "dsnames": ["value"], + "host": "host2", + "plugin": "cpu", + "type": "percent" + } + ]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 3, len(*metrics)) + + assert.Equal(t, "interface", (*metrics)[0].Plugin) + assert.Equal(t, "interface", (*metrics)[1].Plugin) + assert.Equal(t, "cpu", (*metrics)[2].Plugin) + assert.Equal(t, float64(100), (*metrics)[0].Values[0]) + assert.Equal(t, float64(200), (*metrics)[1].Values[0]) + assert.Equal(t, float64(50), (*metrics)[2].Values[0]) + }) + + t.Run("parse multi-dimensional metric", func(t *testing.T) { + input := []byte(`[{ + "values": [2112, 1001, 5555], + "dstypes": ["derive", "counter", "gauge"], + "dsnames": ["rx", "tx", "errors"], + "host": "localhost", + "plugin": "virt", + "type": "if_packets" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, 3, len(metric.Values)) + assert.Equal(t, []float64{2112, 1001, 5555}, metric.Values) + assert.Equal(t, []string{"derive", "counter", "gauge"}, metric.Dstypes) + assert.Equal(t, []string{"rx", "tx", "errors"}, metric.Dsnames) + }) + + t.Run("parse metric without optional fields", func(t *testing.T) { + input := []byte(`[{ + "values": [42], + "dstypes": ["gauge"], + "dsnames": ["value"], + "host": "localhost", + "plugin": "memory", + "type": "memory" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, "", metric.PluginInstance) + assert.Equal(t, "", metric.TypeInstance) + }) + + t.Run("parse metric with time and interval", func(t *testing.T) { + input := []byte(`[{ + "values": [100], + "dstypes": ["derive"], + "dsnames": ["samples"], + "time": 1609459200.5, + "interval": 10.5, + "host": "localhost", + "plugin": "cpu", + "type": "cpu" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.NotNil(t, metric.Time) + assert.Equal(t, float64(10.5), metric.Interval) + }) + + t.Run("parse metric with metadata", func(t *testing.T) { + input := []byte(`[{ + "values": [100], + "dstypes": ["derive"], + "dsnames": ["samples"], + "host": "localhost", + "plugin": "cpu", + "type": "cpu", + "meta": { + "key1": "value1", + "key2": 123 + } + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + }) + + t.Run("error on invalid JSON", func(t *testing.T) { + input := []byte(`{invalid json}`) + + metrics, err := ParseInputByte(input) + require.Error(t, err) + assert.Nil(t, metrics) + }) + + t.Run("error on non-array JSON", func(t *testing.T) { + input := []byte(`{"values": [100]}`) + + metrics, err := ParseInputByte(input) + require.Error(t, err) + assert.Nil(t, metrics) + }) + + t.Run("parse empty array", func(t *testing.T) { + input := []byte(`[]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + assert.Equal(t, 0, len(*metrics)) + }) + + t.Run("parse metric with all dstype variations", func(t *testing.T) { + input := []byte(`[{ + "values": [1, 2, 3, 4], + "dstypes": ["derive", "counter", "gauge", "absolute"], + "dsnames": ["d1", "d2", "d3", "d4"], + "host": "localhost", + "plugin": "test", + "type": "test" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, 4, len(metric.Values)) + assert.Equal(t, []string{"derive", "counter", "gauge", "absolute"}, metric.Dstypes) + }) + + t.Run("parse real-world virt plugin data", func(t *testing.T) { + input := []byte(`[ + { + "values": [1234.5, 5678.9], + "dstypes": ["derive", "counter"], + "dsnames": ["rx", "tx"], + "host": "controller-0.redhat.local", + "time": 1609459200, + "interval": 5, + "plugin": "virt", + "plugin_instance": "instance-00000001", + "type": "if_packets", + "type_instance": "tap73125d-60" + } + ]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, 2, len(metric.Values)) + assert.Equal(t, "controller-0.redhat.local", metric.Host) + assert.Equal(t, "virt", metric.Plugin) + assert.Equal(t, "instance-00000001", metric.PluginInstance) + assert.Equal(t, "if_packets", metric.Type) + assert.Equal(t, "tap73125d-60", metric.TypeInstance) + assert.Equal(t, []string{"rx", "tx"}, metric.Dsnames) + }) + + t.Run("parse metric with floating point values", func(t *testing.T) { + input := []byte(`[{ + "values": [123.456, 789.012], + "dstypes": ["gauge", "gauge"], + "dsnames": ["value1", "value2"], + "host": "localhost", + "plugin": "cpu", + "type": "percent" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.InDelta(t, 123.456, metric.Values[0], 0.001) + assert.InDelta(t, 789.012, metric.Values[1], 0.001) + }) + + t.Run("parse metric with zero values", func(t *testing.T) { + input := []byte(`[{ + "values": [0, 0, 0], + "dstypes": ["derive", "derive", "derive"], + "dsnames": ["a", "b", "c"], + "host": "localhost", + "plugin": "test", + "type": "test" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, []float64{0, 0, 0}, metric.Values) + }) + + t.Run("parse metric with negative values", func(t *testing.T) { + input := []byte(`[{ + "values": [-100, -50.5], + "dstypes": ["gauge", "gauge"], + "dsnames": ["temp", "pressure"], + "host": "localhost", + "plugin": "sensors", + "type": "temperature" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, float64(-100), metric.Values[0]) + assert.Equal(t, float64(-50.5), metric.Values[1]) + }) + + t.Run("parse metric with very large values", func(t *testing.T) { + input := []byte(`[{ + "values": [9999999999999, 1234567890123], + "dstypes": ["counter", "counter"], + "dsnames": ["bytes_in", "bytes_out"], + "host": "localhost", + "plugin": "interface", + "type": "if_octets" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, float64(9999999999999), metric.Values[0]) + assert.Equal(t, float64(1234567890123), metric.Values[1]) + }) + + t.Run("parse metric with special characters in strings", func(t *testing.T) { + input := []byte(`[{ + "values": [100], + "dstypes": ["gauge"], + "dsnames": ["value"], + "host": "host-name.with-dashes", + "plugin": "plugin_with_underscores", + "plugin_instance": "instance.0", + "type": "type-name", + "type_instance": "instance_name" + }]`) + + metrics, err := ParseInputByte(input) + require.NoError(t, err) + require.NotNil(t, metrics) + require.Equal(t, 1, len(*metrics)) + + metric := (*metrics)[0] + assert.Equal(t, "host-name.with-dashes", metric.Host) + assert.Equal(t, "plugin_with_underscores", metric.Plugin) + assert.Equal(t, "instance.0", metric.PluginInstance) + assert.Equal(t, "type-name", metric.Type) + assert.Equal(t, "instance_name", metric.TypeInstance) + }) +} diff --git a/plugins/handler/sensubility-metrics/pkg/sensu/sensu_test.go b/plugins/handler/sensubility-metrics/pkg/sensu/sensu_test.go new file mode 100644 index 00000000..72021269 --- /dev/null +++ b/plugins/handler/sensubility-metrics/pkg/sensu/sensu_test.go @@ -0,0 +1,480 @@ +package sensu + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsMsgValid(t *testing.T) { + t.Run("valid message with all required fields", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{ + Client: "test-client", + Check: "test-check", + Severity: "warning", + }, + Annotations: Annotations{ + Command: "test-command", + Output: "test output", + }, + } + + assert.True(t, IsMsgValid(msg)) + }) + + t.Run("valid message with minimal required fields", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{ + Client: "test-client", + }, + } + + assert.True(t, IsMsgValid(msg)) + }) + + t.Run("invalid message with missing StartsAt", func(t *testing.T) { + msg := Message{ + Labels: Labels{ + Client: "test-client", + }, + } + + assert.False(t, IsMsgValid(msg)) + }) + + t.Run("invalid message with empty StartsAt", func(t *testing.T) { + msg := Message{ + StartsAt: "", + Labels: Labels{ + Client: "test-client", + }, + } + + assert.False(t, IsMsgValid(msg)) + }) + + t.Run("invalid message with missing Client", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{}, + } + + assert.False(t, IsMsgValid(msg)) + }) + + t.Run("invalid message with empty Client", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{ + Client: "", + }, + } + + assert.False(t, IsMsgValid(msg)) + }) + + t.Run("invalid message with both fields missing", func(t *testing.T) { + msg := Message{} + + assert.False(t, IsMsgValid(msg)) + }) + + t.Run("valid message with optional fields", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{ + Client: "test-client", + Check: "disk-usage", + Severity: "critical", + }, + Annotations: Annotations{ + Command: "/usr/bin/check_disk", + Issued: 1234567890, + Executed: 1234567891, + Duration: 1.5, + Output: "CRITICAL - disk usage at 95%", + Status: 2, + StartsAt: "2023-01-01T00:00:00Z", + }, + } + + assert.True(t, IsMsgValid(msg)) + }) +} + +func TestIsOutputValid(t *testing.T) { + t.Run("valid output with single service", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "test-service", + Container: "test-container", + Status: "running", + Healthy: 1.0, + }, + } + + assert.True(t, IsOutputValid(outputs)) + }) + + t.Run("valid output with multiple services", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "service1", + Container: "container1", + Status: "running", + Healthy: 1.0, + }, + { + Service: "service2", + Container: "container2", + Status: "stopped", + Healthy: 0.0, + }, + { + Service: "service3", + Container: "container3", + Status: "running", + Healthy: 1.0, + }, + } + + assert.True(t, IsOutputValid(outputs)) + }) + + t.Run("valid output with minimal fields", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "minimal-service", + }, + } + + assert.True(t, IsOutputValid(outputs)) + }) + + t.Run("invalid output with missing service", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Container: "test-container", + Status: "running", + Healthy: 1.0, + }, + } + + assert.False(t, IsOutputValid(outputs)) + }) + + t.Run("invalid output with empty service", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "", + Container: "test-container", + }, + } + + assert.False(t, IsOutputValid(outputs)) + }) + + t.Run("invalid output with one valid and one invalid", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "valid-service", + Container: "container1", + }, + { + Service: "", + Container: "container2", + }, + } + + assert.False(t, IsOutputValid(outputs)) + }) + + t.Run("valid empty output array", func(t *testing.T) { + outputs := HealthCheckOutput{} + + assert.True(t, IsOutputValid(outputs)) + }) + + t.Run("invalid output with missing service in middle", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "service1", + }, + { + Service: "", + }, + { + Service: "service3", + }, + } + + assert.False(t, IsOutputValid(outputs)) + }) +} + +func TestBuildMsgErr(t *testing.T) { + t.Run("error with missing StartsAt", func(t *testing.T) { + msg := Message{ + Labels: Labels{ + Client: "test-client", + }, + } + + err := BuildMsgErr(msg) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Contains(t, eMF.Fields, "startsAt") + assert.NotContains(t, eMF.Fields, "labels.client") + assert.Contains(t, err.Error(), "startsAt") + }) + + t.Run("error with missing Client", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + } + + err := BuildMsgErr(msg) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Contains(t, eMF.Fields, "labels.client") + assert.NotContains(t, eMF.Fields, "startsAt") + assert.Contains(t, err.Error(), "labels.client") + }) + + t.Run("error with both fields missing", func(t *testing.T) { + msg := Message{} + + err := BuildMsgErr(msg) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Contains(t, eMF.Fields, "startsAt") + assert.Contains(t, eMF.Fields, "labels.client") + assert.Contains(t, err.Error(), "startsAt") + assert.Contains(t, err.Error(), "labels.client") + assert.Contains(t, err.Error(), "missing fields in received data") + }) + + t.Run("error with valid message returns empty error", func(t *testing.T) { + msg := Message{ + StartsAt: "2023-01-01T00:00:00Z", + Labels: Labels{ + Client: "test-client", + }, + } + + err := BuildMsgErr(msg) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Empty(t, eMF.Fields) + }) + + t.Run("error message format", func(t *testing.T) { + msg := Message{} + + err := BuildMsgErr(msg) + require.NotNil(t, err) + + errorMsg := err.Error() + assert.Contains(t, errorMsg, "missing fields in received data") + assert.Contains(t, errorMsg, "(") + assert.Contains(t, errorMsg, ")") + }) +} + +func TestBuildOutputsErr(t *testing.T) { + t.Run("error with single missing service", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Container: "test-container", + }, + } + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Contains(t, eMF.Fields, "annotations.output[0].service") + assert.Contains(t, err.Error(), "annotations.output[0].service") + }) + + t.Run("error with multiple missing services", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Container: "container1", + }, + { + Service: "valid-service", + }, + { + Container: "container3", + }, + } + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Contains(t, eMF.Fields, "annotations.output[0].service") + assert.Contains(t, eMF.Fields, "annotations.output[2].service") + assert.NotContains(t, eMF.Fields, "annotations.output[1].service") + }) + + t.Run("error with all services missing", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Container: "container1", + }, + { + Container: "container2", + }, + { + Container: "container3", + }, + } + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Len(t, eMF.Fields, 3) + assert.Contains(t, eMF.Fields, "annotations.output[0].service") + assert.Contains(t, eMF.Fields, "annotations.output[1].service") + assert.Contains(t, eMF.Fields, "annotations.output[2].service") + }) + + t.Run("error with valid outputs returns empty error", func(t *testing.T) { + outputs := HealthCheckOutput{ + { + Service: "service1", + Container: "container1", + }, + { + Service: "service2", + Container: "container2", + }, + } + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Empty(t, eMF.Fields) + }) + + t.Run("error with empty outputs array", func(t *testing.T) { + outputs := HealthCheckOutput{} + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + eMF, ok := err.(*ErrMissingFields) + require.True(t, ok) + assert.Empty(t, eMF.Fields) + }) + + t.Run("error index format in message", func(t *testing.T) { + outputs := HealthCheckOutput{ + {}, + {}, + {}, + } + + err := BuildOutputsErr(outputs) + require.NotNil(t, err) + + errorMsg := err.Error() + assert.Contains(t, errorMsg, "[0]") + assert.Contains(t, errorMsg, "[1]") + assert.Contains(t, errorMsg, "[2]") + }) +} + +func TestErrMissingFields(t *testing.T) { + t.Run("error message with single field", func(t *testing.T) { + err := &ErrMissingFields{ + Fields: []string{"field1"}, + } + + assert.Equal(t, "missing fields in received data (field1)", err.Error()) + }) + + t.Run("error message with multiple fields", func(t *testing.T) { + err := &ErrMissingFields{ + Fields: []string{"field1", "field2", "field3"}, + } + + errorMsg := err.Error() + assert.Contains(t, errorMsg, "missing fields in received data") + assert.Contains(t, errorMsg, "field1") + assert.Contains(t, errorMsg, "field2") + assert.Contains(t, errorMsg, "field3") + assert.Contains(t, errorMsg, ", ") + }) + + t.Run("error message with empty fields", func(t *testing.T) { + err := &ErrMissingFields{ + Fields: []string{}, + } + + assert.Equal(t, "missing fields in received data ()", err.Error()) + }) + + t.Run("add missing field", func(t *testing.T) { + err := &ErrMissingFields{ + Fields: []string{}, + } + + err.addMissingField("field1") + assert.Contains(t, err.Fields, "field1") + assert.Len(t, err.Fields, 1) + + err.addMissingField("field2") + assert.Contains(t, err.Fields, "field1") + assert.Contains(t, err.Fields, "field2") + assert.Len(t, err.Fields, 2) + }) + + t.Run("add multiple missing fields", func(t *testing.T) { + err := &ErrMissingFields{} + + err.addMissingField("field1") + err.addMissingField("field2") + err.addMissingField("field3") + + assert.Len(t, err.Fields, 3) + assert.Equal(t, "field1", err.Fields[0]) + assert.Equal(t, "field2", err.Fields[1]) + assert.Equal(t, "field3", err.Fields[2]) + }) + + t.Run("error message format with long field names", func(t *testing.T) { + err := &ErrMissingFields{ + Fields: []string{ + "annotations.output[0].service", + "annotations.output[1].service", + "labels.client", + }, + } + + errorMsg := err.Error() + assert.Contains(t, errorMsg, "annotations.output[0].service") + assert.Contains(t, errorMsg, "annotations.output[1].service") + assert.Contains(t, errorMsg, "labels.client") + }) +} diff --git a/plugins/transport/socket/main.go b/plugins/transport/socket/main.go index bfac8fa7..58efb538 100644 --- a/plugins/transport/socket/main.go +++ b/plugins/transport/socket/main.go @@ -20,11 +20,13 @@ import ( ) const ( - maxBufferSize = 65535 - udp = "udp" - unix = "unix" - tcp = "tcp" - msgLengthSize = 8 + maxBufferSize = 65535 // 64KB - initial buffer size for all socket types and max for UDP (OS datagram limit) + maxBufferSizeUnix = 10485760 // 10MB - max buffer size for Unix domain sockets + maxBufferSizeTCP = 104857600 // 100MB - max buffer size for TCP (stream-based, can handle very large messages) + udp = "udp" + unix = "unix" + tcp = "tcp" + msgLengthSize = 8 ) var ( @@ -138,6 +140,17 @@ func (s *Socket) initTCPSocket() *net.TCPListener { return pc } +func (s *Socket) getMaxBufferSize() int64 { + switch s.conf.Type { + case udp: + return maxBufferSize + case tcp: + return maxBufferSizeTCP + default: + return maxBufferSizeUnix + } +} + func (s *Socket) WriteTCPMsg(w transport.WriteFn, msgBuffer []byte, n int) (int64, error) { var pos int64 var length int64 @@ -165,10 +178,13 @@ func (s *Socket) WriteTCPMsg(w transport.WriteFn, msgBuffer []byte, n int) (int6 return pos, nil } -func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w transport.WriteFn) { +func (s *Socket) ReceiveData(initialBuffSize int64, done chan bool, pc net.Conn, w transport.WriteFn) { defer pc.Close() - msgBuffer := make([]byte, maxBuffSize) + currentBuffSize := initialBuffSize + maxBuffSize := s.getMaxBufferSize() + msgBuffer := make([]byte, currentBuffSize) var remainingMsg []byte + for { n, err := pc.Read(msgBuffer) if err != nil || n < 1 { @@ -180,17 +196,40 @@ func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w t } return } - msgBuffer = append(remainingMsg, msgBuffer...) - // whole buffer was used, so we are potentially handling larger message - if n == len(msgBuffer) { - s.logger.Warnf("full read buffer used") + // Combine remaining data from previous iteration with newly read data + var data []byte + if len(remainingMsg) > 0 { + data = make([]byte, len(remainingMsg)+n) + copy(data, remainingMsg) + copy(data[len(remainingMsg):], msgBuffer[:n]) + } else { + data = msgBuffer[:n] + } + totalSize := len(data) + + // Check if buffer was completely filled - message may have been truncated + if n == int(currentBuffSize) { + if s.conf.Type == tcp { + s.logger.Debugf("full read buffer used (%d bytes), TCP will handle continuation if needed", n) + } else { + // For UDP/Unix sockets, buffer being full means message was likely truncated + if currentBuffSize < maxBuffSize { + newSize := currentBuffSize * 2 + if newSize > maxBuffSize { + newSize = maxBuffSize + } + s.logger.Warnf("message may have been truncated (buffer filled with %d bytes), growing buffer from %d to %d bytes for next message", currentBuffSize, currentBuffSize, newSize) + currentBuffSize = newSize + msgBuffer = make([]byte, currentBuffSize) + } else { + s.logger.Errorf(nil, "message truncated: buffer size (%d bytes) exceeded for %s socket and already at maximum buffer size (%d bytes)", currentBuffSize, s.conf.Type, maxBuffSize) + } + } } - - n += len(remainingMsg) if s.conf.DumpMessages.Enabled { - _, err := s.dumpBuf.Write(msgBuffer[:n]) + _, err := s.dumpBuf.Write(data) if err != nil { s.logger.Errorf(err, "writing to dump buffer") } @@ -202,16 +241,17 @@ func (s *Socket) ReceiveData(maxBuffSize int64, done chan bool, pc net.Conn, w t } if s.conf.Type == tcp { - parsed, err := s.WriteTCPMsg(w, msgBuffer, n) + parsed, err := s.WriteTCPMsg(w, data, totalSize) if err != nil { s.logger.Errorf(err, "error, while parsing messages") return } - remainingMsg = make([]byte, int64(n)-parsed) - copy(remainingMsg, msgBuffer[parsed:n]) + remainingMsg = make([]byte, int64(totalSize)-parsed) + copy(remainingMsg, data[parsed:totalSize]) } else { - w(msgBuffer[:n]) + w(data) msgCount++ + remainingMsg = nil } } } @@ -223,7 +263,7 @@ func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { case udp: pc = s.initUDPSocket() if pc == (*net.UDPConn)(nil) { - s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: "+s.conf.Type) + s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: %s", s.conf.Type) return } go s.ReceiveData(maxBufferSize, done, pc, w) @@ -231,7 +271,7 @@ func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { case tcp: TCPSocket := s.initTCPSocket() if TCPSocket == nil { - s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: "+s.conf.Type) + s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: %s", s.conf.Type) return } go func() { @@ -254,7 +294,7 @@ func (s *Socket) Run(ctx context.Context, w transport.WriteFn, done chan bool) { default: pc = s.initUnixSocket() if pc == (*net.UnixConn)(nil) { - s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: "+s.conf.Type) + s.logger.Errorf(nil, "Failed to initialize socket transport plugin with type: %s", s.conf.Type) return } go s.ReceiveData(maxBufferSize, done, pc, w) diff --git a/plugins/transport/socket/main_test.go b/plugins/transport/socket/main_test.go index d734248c..c2843ea7 100644 --- a/plugins/transport/socket/main_test.go +++ b/plugins/transport/socket/main_test.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "bytes" "context" "encoding/binary" @@ -12,11 +13,12 @@ import ( "time" "github.com/infrawatch/apputils/logging" + "github.com/infrawatch/sg-core/pkg/data" "github.com/stretchr/testify/require" "gopkg.in/go-playground/assert.v1" ) -const regularBuffSize = 16384 +const regularBuffSize = 65535 // default buffer size const addition = "wubba lubba dub dub" func TestUnixSocketTransport(t *testing.T) { @@ -28,40 +30,106 @@ func TestUnixSocketTransport(t *testing.T) { logger, err := logging.NewLogger(logging.DEBUG, logpath) require.NoError(t, err) - sktpath := path.Join(tmpdir, "socket") - skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) - require.NoError(t, err) - defer skt.Close() + t.Run("test normal message", func(t *testing.T) { + // Create a normal-sized message (5KB) + msg := make([]byte, 5000) + for i := 0; i < len(msg); i++ { + msg[i] = byte('A') + } + marker := []byte("--END--") + copy(msg[len(msg)-len(marker):], marker) - trans := Socket{ - conf: configT{ - Path: sktpath, - }, - logger: &logWrapper{ - l: logger, - }, - } + sktpath := path.Join(tmpdir, "socket1") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + }, + logger: &logWrapper{ + l: logger, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + var receivedMsg []byte + go trans.Run(ctx, func(mess []byte) { + receivedMsg = mess + wg.Done() + }, make(chan bool)) + + // Wait for socket file to be created + for { + stat, err := os.Stat(sktpath) + require.NoError(t, err) + if stat.Mode()&os.ModeType == os.ModeSocket { + break + } + time.Sleep(250 * time.Millisecond) + } + + wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) + require.NoError(t, err) + _, err = wskt.Write(msg) + require.NoError(t, err) + + wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() + + // Verify we received the complete message + assert.Equal(t, len(msg), len(receivedMsg)) + // Verify the end marker is present + endMarkerPos := len(receivedMsg) - len(marker) + assert.Equal(t, string(marker), string(receivedMsg[endMarkerPos:])) + }) t.Run("test large message transport", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { + // Create a message larger than initial buffer to test dynamic buffer growth + largeBuffSize := regularBuffSize * 2 // 131070 bytes + msg := make([]byte, largeBuffSize) + for i := 0; i < largeBuffSize; i++ { msg[i] = byte('X') } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) + msg[largeBuffSize-1] = byte('$') + msg = append(msg, []byte(addition)...) // Total: 131089 bytes + + // Setup socket using same pattern as sendUnixSocketMessage + sktpath := path.Join(tmpdir, "socket2") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + }, + logger: &logWrapper{ + l: logger, + }, + } - // verify transport ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var receivedMsgs [][]byte + var mutex sync.Mutex wg := sync.WaitGroup{} + wg.Add(3) // Expecting 3 messages + go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + mutex.Lock() + receivedMsgs = append(receivedMsgs, mess) + mutex.Unlock() wg.Done() }, make(chan bool)) - // wait for socket file to be created + // Wait for socket file to be created for { stat, err := os.Stat(sktpath) require.NoError(t, err) @@ -71,18 +139,83 @@ func TestUnixSocketTransport(t *testing.T) { time.Sleep(250 * time.Millisecond) } - // write to socket wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) require.NoError(t, err) + defer wskt.Close() + + // Send the same message 3 times + _, err = wskt.Write(msg) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + _, err = wskt.Write(msg) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + _, err = wskt.Write(msg) require.NoError(t, err) - cancel() wg.Wait() - wskt.Close() + + // Verify we received 3 messages + require.Equal(t, 3, len(receivedMsgs)) + + // First message: the message is truncated to the maximum 64KB (65535 bytes) + require.Equal(t, len(receivedMsgs[0]), regularBuffSize) + + // Second message: check for 128KB (131070 bytes) with '$' at position 131069 + require.Equal(t, len(receivedMsgs[1]), largeBuffSize) + assert.Equal(t, byte('$'), receivedMsgs[1][131069]) + + // Third message: check for > 128KB (131070 bytes) with "wubba lubba dub dub" at the end + require.GreaterOrEqual(t, len(receivedMsgs[2]), largeBuffSize+len(addition)) + endStr := string(receivedMsgs[2][len(receivedMsgs[2])-len(addition):]) + assert.Equal(t, addition, endStr) }) } +// Helper function to send and receive UDP socket message +func sendUDPSocketMessage(t *testing.T, logger *logging.Logger, addr string, msg []byte) ([]byte, error) { + trans := Socket{ + conf: configT{ + Socketaddr: addr, + Type: "udp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + var receivedMsg []byte + messageReceived := false + go trans.Run(ctx, func(mess []byte) { + receivedMsg = mess + messageReceived = true + wg.Done() + }, make(chan bool)) + + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + require.NoError(t, err) + wskt, err := net.DialUDP("udp", nil, udpAddr) + require.NoError(t, err) + _, writeErr := wskt.Write(msg) + + if writeErr == nil && messageReceived { + wg.Wait() + } + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() + + return receivedMsg, writeErr +} + func TestUdpSocketTransport(t *testing.T) { tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") require.NoError(t, err) @@ -92,50 +225,257 @@ func TestUdpSocketTransport(t *testing.T) { logger, err := logging.NewLogger(logging.DEBUG, logpath) require.NoError(t, err) + t.Run("test normal message", func(t *testing.T) { + // Create a normal message (5KB) + msg := make([]byte, 5000) + for i := 0; i < len(msg); i++ { + msg[i] = byte('U') + } + marker := []byte("--UDP-END--") + copy(msg[len(msg)-len(marker):], marker) + + receivedMsg, err := sendUDPSocketMessage(t, logger, "127.0.0.1:8650", msg) + require.NoError(t, err) + + // Verify we received the complete message + assert.Equal(t, len(msg), len(receivedMsg)) + // Verify the end marker is present + endMarkerPos := len(receivedMsg) - len(marker) + assert.Equal(t, string(marker), string(receivedMsg[endMarkerPos:])) + }) + + t.Run("test large message transport", func(t *testing.T) { + // Create message that exceeds UDP datagram limits + // UDP max payload is ~65507 bytes, we're trying to send 65535 + 19 = 65554 bytes + largeBuffSize := regularBuffSize - len(addition) + msg := make([]byte, largeBuffSize) + for i := 0; i < largeBuffSize; i++ { + msg[i] = byte('X') + } + msg[largeBuffSize-1] = byte('$') + msg = append(msg, []byte(addition)...) + + _, err := sendUDPSocketMessage(t, logger, "127.0.0.1:8652", msg) + + // Verify that sending a message that's too large for UDP fails + require.Error(t, err) + }) +} + +// Helper function to connect to TCP with retries +func connectTCPWithRetry(t *testing.T, addr string) net.Conn { + wskt, err := net.Dial("tcp", addr) + if err != nil { + for retries := 0; err != nil && retries < 3; retries++ { + time.Sleep(500 * time.Millisecond) + wskt, err = net.Dial("tcp", addr) + } + } + require.NoError(t, err) + return wskt +} + +// Helper function to create a TCP message with length header +func createTCPMessage(t *testing.T, content []byte) []byte { + msgLength := new(bytes.Buffer) + err := binary.Write(msgLength, binary.LittleEndian, uint64(len(content))) + require.NoError(t, err) + return append(msgLength.Bytes(), content...) +} + +// Helper function to send and verify TCP socket message with marker +func sendTCPSocketMessage(t *testing.T, logger *logging.Logger, addr string, msgSize int, fillByte byte, marker []byte) { trans := Socket{ conf: configT{ - Socketaddr: "127.0.0.1:8642", - Type: "udp", + Socketaddr: addr, + Type: "tcp", }, logger: &logWrapper{ l: logger, }, } - t.Run("test large message transport", func(t *testing.T) { - msg := make([]byte, regularBuffSize) + msgContent := make([]byte, msgSize) + for i := 0; i < msgSize; i++ { + msgContent[i] = fillByte + } + copy(msgContent[len(msgContent)-len(marker):], marker) + + fullMsg := createTCPMessage(t, msgContent) + + ctx, cancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + wg.Add(1) + go trans.Run(ctx, func(mess []byte) { + assert.Equal(t, msgSize, len(mess)) + endMarkerPos := len(mess) - len(marker) + assert.Equal(t, string(marker), string(mess[endMarkerPos:])) + wg.Done() + }, make(chan bool)) + + time.Sleep(100 * time.Millisecond) + + wskt := connectTCPWithRetry(t, addr) + _, err := wskt.Write(fullMsg) + require.NoError(t, err) + + wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() +} + +func TestTcpSocketTransport(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("test normal message", func(t *testing.T) { + // Create a normal message (5KB) + sendTCPSocketMessage(t, logger, "127.0.0.1:8660", 5000, 'T', []byte("--TCP-END--")) + }) + + t.Run("test message exceeding initial buffer", func(t *testing.T) { + // Create a message larger than initial buffer (100KB) + sendTCPSocketMessage(t, logger, "127.0.0.1:8661", 100000, 'B', []byte("--LARGE-TCP--")) + }) + + t.Run("test multiple large messages", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8663", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + numMessages := 3 + messageSizes := []int{80000, 120000, 90000} + var combinedMsg bytes.Buffer + + // Create multiple large messages + for i := 0; i < numMessages; i++ { + msgContent := make([]byte, messageSizes[i]) + fillByte := byte('0' + i) + for j := 0; j < messageSizes[i]; j++ { + msgContent[j] = fillByte + } + combinedMsg.Write(createTCPMessage(t, msgContent)) + } + + // Setup message verification + ctx, cancel := context.WithCancel(context.Background()) + receivedCount := 0 + var mutex sync.Mutex + wg := sync.WaitGroup{} + wg.Add(numMessages) + + go trans.Run(ctx, func(mess []byte) { + mutex.Lock() + defer mutex.Unlock() + + // Verify message size matches one of our expected sizes + found := false + for i, expectedSize := range messageSizes { + if len(mess) == expectedSize { + expectedByte := byte('0' + i) + allMatch := true + for _, b := range mess { + if b != expectedByte { + allMatch = false + break + } + } + if allMatch { + found = true + receivedCount++ + wg.Done() + break + } + } + } + assert.Equal(t, true, found) + }, make(chan bool)) + + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + + // Connect and send all messages + wskt := connectTCPWithRetry(t, "127.0.0.1:8663") + _, err = wskt.Write(combinedMsg.Bytes()) + require.NoError(t, err) + + wg.Wait() + + mutex.Lock() + assert.Equal(t, numMessages, receivedCount) + mutex.Unlock() + + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() + }) + + t.Run("test large message transport multiple connections", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8665", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + msgContent := make([]byte, regularBuffSize) for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') + msgContent[i] = byte('X') } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) + msgContent[regularBuffSize-1] = byte('$') + msgContent = append(msgContent, []byte(addition)...) + msg := createTCPMessage(t, msgContent) // verify transport ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} + wg.Add(2) go trans.Run(ctx, func(mess []byte) { - wg.Add(1) strmsg := string(mess) assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct wg.Done() }, make(chan bool)) + // Wait for socket to be ready + time.Sleep(100 * time.Millisecond) + // write to socket - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8642") + wskt1 := connectTCPWithRetry(t, "127.0.0.1:8665") + + // We shouldn't need to retry the second connection, if this fails, then something is wrong + wskt2, err := net.Dial("tcp", "127.0.0.1:8665") require.NoError(t, err) - wskt, err := net.DialUDP("udp", nil, addr) + + _, err = wskt1.Write(msg) require.NoError(t, err) - _, err = wskt.Write(msg) + _, err = wskt2.Write(msg) require.NoError(t, err) - cancel() wg.Wait() - wskt.Close() + cancel() + time.Sleep(100 * time.Millisecond) + wskt1.Close() + wskt2.Close() }) } -func TestTcpSocketTransport(t *testing.T) { +func TestNew(t *testing.T) { tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") require.NoError(t, err) defer os.RemoveAll(tmpdir) @@ -144,103 +484,545 @@ func TestTcpSocketTransport(t *testing.T) { logger, err := logging.NewLogger(logging.DEBUG, logpath) require.NoError(t, err) - trans := Socket{ - conf: configT{ - Socketaddr: "127.0.0.1:8642", - Type: "tcp", - }, - logger: &logWrapper{ - l: logger, - }, + trans := New(logger) + require.NotNil(t, trans) + + socket, ok := trans.(*Socket) + require.True(t, ok) + require.NotNil(t, socket.logger) + require.NotNil(t, socket.logger.l) +} + +func TestListen(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + trans := New(logger) + socket := trans.(*Socket) + + // Listen should not panic and should print the event + testEvent := data.Event{ + Index: "test-index", + Time: 123.456, + Type: data.EVENT, + Publisher: "test-publisher", + Severity: data.INFO, } + socket.Listen(testEvent) +} - t.Run("test large message transport single connection", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') +func TestConfig(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("valid unix socket config", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: unix +path: /tmp/test.sock +` + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "unix", socket.conf.Type) + assert.Equal(t, "/tmp/test.sock", socket.conf.Path) + }) + + t.Run("valid udp socket config", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: udp +socketaddr: 127.0.0.1:9999 +` + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "udp", socket.conf.Type) + assert.Equal(t, "127.0.0.1:9999", socket.conf.Socketaddr) + }) + + t.Run("valid tcp socket config", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: tcp +socketaddr: 127.0.0.1:8888 +` + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "tcp", socket.conf.Type) + assert.Equal(t, "127.0.0.1:8888", socket.conf.Socketaddr) + }) + + t.Run("config with dump messages enabled", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + dumpPath := path.Join(tmpdir, "dump.txt") + config := ` +type: unix +path: /tmp/test.sock +dumpMessages: + enabled: true + path: ` + dumpPath + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, true, socket.conf.DumpMessages.Enabled) + assert.Equal(t, dumpPath, socket.conf.DumpMessages.Path) + require.NotNil(t, socket.dumpFile) + require.NotNil(t, socket.dumpBuf) + socket.dumpFile.Close() + }) + + t.Run("invalid socket type", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: invalid +path: /tmp/test.sock +` + err := socket.Config([]byte(config)) + require.Error(t, err) + require.Contains(t, err.Error(), "unable to determine socket type") + }) + + t.Run("unix socket without path", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: unix +` + err := socket.Config([]byte(config)) + require.Error(t, err) + require.Contains(t, err.Error(), "path") + }) + + t.Run("udp socket without socketaddr", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: udp +` + err := socket.Config([]byte(config)) + require.Error(t, err) + require.Contains(t, err.Error(), "socketaddr") + }) + + t.Run("tcp socket without socketaddr", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: tcp +` + err := socket.Config([]byte(config)) + require.Error(t, err) + require.Contains(t, err.Error(), "socketaddr") + }) + + t.Run("invalid yaml config", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +this is not: valid: yaml +` + err := socket.Config([]byte(config)) + require.Error(t, err) + }) + + t.Run("case insensitive socket type", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +type: TCP +socketaddr: 127.0.0.1:8888 +` + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "tcp", socket.conf.Type) + }) + + t.Run("default values", func(t *testing.T) { + trans := New(logger) + socket := trans.(*Socket) + + config := ` +path: /tmp/test.sock +` + err := socket.Config([]byte(config)) + require.NoError(t, err) + assert.Equal(t, "unix", socket.conf.Type) + assert.Equal(t, false, socket.conf.DumpMessages.Enabled) + assert.Equal(t, "/dev/stdout", socket.conf.DumpMessages.Path) + }) +} + +func TestInitializationErrors(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("unix socket initialization with path in non-existent directory", func(t *testing.T) { + // Create a file where we want to create a directory, causing mkdir to fail + blockingFile := path.Join(tmpdir, "blocking_file") + err := os.WriteFile(blockingFile, []byte("test"), 0600) + require.NoError(t, err) + + // Try to create a socket in a "subdirectory" of this file (which is impossible) + invalidPath := path.Join(blockingFile, "subdir", "socket.sock") + + trans := Socket{ + conf: configT{ + Path: invalidPath, + Type: unix, + }, + logger: &logWrapper{ + l: logger, + }, } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) - msgLength := new(bytes.Buffer) - err := binary.Write(msgLength, binary.LittleEndian, uint64(len(msg))) + + result := trans.initUnixSocket() + require.Nil(t, result) + }) + + t.Run("udp socket initialization with invalid address", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "not-a-valid-address:::::99999", + Type: udp, + }, + logger: &logWrapper{ + l: logger, + }, + } + + result := trans.initUDPSocket() + require.Nil(t, result) + }) + + t.Run("udp socket initialization with address already in use", func(t *testing.T) { + // First, bind to a port + addr, err := net.ResolveUDPAddr(udp, "127.0.0.1:18680") + require.NoError(t, err) + firstConn, err := net.ListenUDP(udp, addr) require.NoError(t, err) - msg = append(msgLength.Bytes(), msg...) + defer firstConn.Close() + + // Now try to bind to the same port + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:18680", + Type: udp, + }, + logger: &logWrapper{ + l: logger, + }, + } + + result := trans.initUDPSocket() + require.Nil(t, result) + }) + + t.Run("tcp socket initialization with invalid address", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "not-a-valid-address:::::99999", + Type: tcp, + }, + logger: &logWrapper{ + l: logger, + }, + } + + result := trans.initTCPSocket() + require.Nil(t, result) + }) + + t.Run("tcp socket initialization with address already in use", func(t *testing.T) { + // First, bind to a port + addr, err := net.ResolveTCPAddr(tcp, "127.0.0.1:18681") + require.NoError(t, err) + firstListener, err := net.ListenTCP(tcp, addr) + require.NoError(t, err) + defer firstListener.Close() + + // Now try to bind to the same port + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:18681", + Type: tcp, + }, + logger: &logWrapper{ + l: logger, + }, + } + + result := trans.initTCPSocket() + require.Nil(t, result) + }) +} + +func TestDumpMessagesFeature(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("unix socket with dump messages enabled", func(t *testing.T) { + dumpPath := path.Join(tmpdir, "dump_unix.txt") + sktpath := path.Join(tmpdir, "socket_dump") + skt, err := os.OpenFile(sktpath, os.O_RDWR|os.O_CREATE, os.ModeSocket|os.ModePerm) + require.NoError(t, err) + defer skt.Close() + + trans := Socket{ + conf: configT{ + Path: sktpath, + Type: unix, + DumpMessages: struct { + Enabled bool + Path string + }{ + Enabled: true, + Path: dumpPath, + }, + }, + logger: &logWrapper{ + l: logger, + }, + } + + // Initialize dump file and buffer + trans.dumpFile, err = os.OpenFile(dumpPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + require.NoError(t, err) + defer trans.dumpFile.Close() + trans.dumpBuf = bufio.NewWriter(trans.dumpFile) - // verify transport ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} + wg.Add(1) + var receivedMsg []byte go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + receivedMsg = mess wg.Done() }, make(chan bool)) - // write to socket - wskt, err := net.Dial("tcp", "127.0.0.1:8642") - if err != nil { - // The socket might not be listening yet, wait a little bit and try to connect again - for retries := 0; err != nil && retries < 3; retries++ { - time.Sleep(2 * time.Second) - wskt, err = net.Dial("tcp", "127.0.0.1:8642") + // Wait for socket file to be created + for { + stat, err := os.Stat(sktpath) + require.NoError(t, err) + if stat.Mode()&os.ModeType == os.ModeSocket { + break } + time.Sleep(250 * time.Millisecond) } + + // Send a message + msg := []byte("test message with dump") + wskt, err := net.DialUnix("unixgram", nil, &net.UnixAddr{Name: sktpath, Net: "unixgram"}) require.NoError(t, err) _, err = wskt.Write(msg) require.NoError(t, err) - cancel() wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) wskt.Close() + + // Verify message was received + assert.Equal(t, string(msg), string(receivedMsg)) + + // Verify message was dumped to file + dumpContent, err := os.ReadFile(dumpPath) + require.NoError(t, err) + require.Contains(t, string(dumpContent), "test message with dump") }) - t.Run("test large message transport multiple connections", func(t *testing.T) { - msg := make([]byte, regularBuffSize) - for i := 0; i < regularBuffSize; i++ { - msg[i] = byte('X') + t.Run("tcp socket with dump messages enabled", func(t *testing.T) { + dumpPath := path.Join(tmpdir, "dump_tcp.txt") + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:18690", + Type: tcp, + DumpMessages: struct { + Enabled bool + Path string + }{ + Enabled: true, + Path: dumpPath, + }, + }, + logger: &logWrapper{ + l: logger, + }, } - msg[regularBuffSize-1] = byte('$') - msg = append(msg, []byte(addition)...) - msgLength := new(bytes.Buffer) - err := binary.Write(msgLength, binary.LittleEndian, uint64(len(msg))) + + // Initialize dump file and buffer + trans.dumpFile, err = os.OpenFile(dumpPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) require.NoError(t, err) - msg = append(msgLength.Bytes(), msg...) + defer trans.dumpFile.Close() + trans.dumpBuf = bufio.NewWriter(trans.dumpFile) + + msgContent := []byte("tcp dump test message") + fullMsg := createTCPMessage(t, msgContent) - // verify transport ctx, cancel := context.WithCancel(context.Background()) wg := sync.WaitGroup{} + wg.Add(1) go trans.Run(ctx, func(mess []byte) { - wg.Add(1) - strmsg := string(mess) - assert.Equal(t, regularBuffSize+len(addition), len(strmsg)) // we received whole message - assert.Equal(t, addition, strmsg[len(strmsg)-len(addition):]) // and the out-of-band part is correct + assert.Equal(t, string(msgContent), string(mess)) wg.Done() }, make(chan bool)) - // write to socket - wskt1, err := net.Dial("tcp", "127.0.0.1:8642") - if err != nil { - // The socket might not be listening yet, wait a little bit and try to connect again - for retries := 0; err != nil && retries < 3; retries++ { - time.Sleep(2 * time.Second) - wskt1, err = net.Dial("tcp", "127.0.0.1:8642") - } + time.Sleep(100 * time.Millisecond) + + wskt := connectTCPWithRetry(t, "127.0.0.1:18690") + _, err = wskt.Write(fullMsg) + require.NoError(t, err) + + wg.Wait() + cancel() + time.Sleep(100 * time.Millisecond) + wskt.Close() + + // Verify message was dumped to file + dumpContent, err := os.ReadFile(dumpPath) + require.NoError(t, err) + require.Contains(t, string(dumpContent), "tcp dump test message") + }) +} + +func TestWriteTCPMsgErrors(t *testing.T) { + tmpdir, err := os.MkdirTemp(".", "socket_test_tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + logpath := path.Join(tmpdir, "test.log") + logger, err := logging.NewLogger(logging.DEBUG, logpath) + require.NoError(t, err) + + t.Run("overflow protection - negative length", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8670", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, } + + // Create a buffer with a message that would cause overflow + msgBuffer := make([]byte, 100) + // Write a very large length value that will overflow when added to position + binary.LittleEndian.PutUint64(msgBuffer[0:8], uint64(0x7FFFFFFFFFFFFFFF)) + + messageCount := 0 + pos, err := trans.WriteTCPMsg(func(data []byte) { + messageCount++ + }, msgBuffer, len(msgBuffer)) + require.NoError(t, err) + // Should stop without processing any messages due to overflow protection + assert.Equal(t, 0, messageCount) + assert.Equal(t, int64(0), pos) + }) + + t.Run("incomplete message - not enough data", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8671", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + // Create a buffer with message length header indicating more data than available + msgBuffer := make([]byte, 20) + // Indicate 100 bytes of data, but we only have 12 bytes after the length header + binary.LittleEndian.PutUint64(msgBuffer[0:8], uint64(100)) + copy(msgBuffer[8:], []byte("test")) + + messageCount := 0 + pos, err := trans.WriteTCPMsg(func(data []byte) { + messageCount++ + }, msgBuffer, len(msgBuffer)) - // We shouldn't need to retry the second connection, if this fails, then something is wrong - wskt2, err := net.Dial("tcp", "127.0.0.1:8642") require.NoError(t, err) + // Should not process the incomplete message + assert.Equal(t, 0, messageCount) + assert.Equal(t, int64(0), pos) + }) - _, err = wskt1.Write(msg) + t.Run("multiple messages with partial last message", func(t *testing.T) { + trans := Socket{ + conf: configT{ + Socketaddr: "127.0.0.1:8672", + Type: "tcp", + }, + logger: &logWrapper{ + l: logger, + }, + } + + var msgBuffer bytes.Buffer + + // First complete message + msg1 := []byte("Complete message 1") + err := binary.Write(&msgBuffer, binary.LittleEndian, uint64(len(msg1))) require.NoError(t, err) - _, err = wskt2.Write(msg) + msgBuffer.Write(msg1) + + // Second complete message + msg2 := []byte("Complete message 2") + err = binary.Write(&msgBuffer, binary.LittleEndian, uint64(len(msg2))) require.NoError(t, err) + msgBuffer.Write(msg2) - cancel() - wg.Wait() - wskt1.Close() - wskt2.Close() + // Third incomplete message (header indicates more data than available) + err = binary.Write(&msgBuffer, binary.LittleEndian, uint64(1000)) + require.NoError(t, err) + msgBuffer.Write([]byte("Incomplete")) + + receivedMessages := []string{} + pos, err := trans.WriteTCPMsg(func(data []byte) { + receivedMessages = append(receivedMessages, string(data)) + }, msgBuffer.Bytes(), msgBuffer.Len()) + + require.NoError(t, err) + // Should process only the two complete messages + assert.Equal(t, 2, len(receivedMessages)) + assert.Equal(t, "Complete message 1", receivedMessages[0]) + assert.Equal(t, "Complete message 2", receivedMessages[1]) + // Position should be at the start of the incomplete message + expectedPos := int64(8 + len(msg1) + 8 + len(msg2)) + assert.Equal(t, expectedPos, pos) }) }