Skip to content

Commit 4a85ec6

Browse files
jentfooMzack9999
andauthored
fix: Accurate session count to avoid constant upward drift (#1348)
* fix: Accurate session count to avoid constant upward drift The register handler incremented the session counter before `SetIDPublicKey` could fail, leaking +1 on duplicate IDs or bad keys. Now only incremented after successful registration. The deregister handler decremented the counter before validating the request, leaking -1 on malformed or unauthorized requests. Removed the explicit decrement entirely (handled in logic below). Cache eviction and TTL expiry silently removed sessions without decrementing the counter, causing monotonic growth. Unified all session decrements into a single cache removal callback that fires on deregistration, eviction, and cache close, filtering to only count entries with a SecretKey (true client sessions). session_total was added as a metric so that even short lived sessions can be viewed in the metrics. * address review nits in session-tracking tests --------- Co-authored-by: Mzack9999 <mzack9999@protonmail.com>
1 parent e3e2878 commit 4a85ec6

7 files changed

Lines changed: 187 additions & 24 deletions

File tree

cmd/interactsh-server/main.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"path/filepath"
1515
"strconv"
1616
"strings"
17+
"sync/atomic"
1718
"time"
1819

1920
_ "net/http/pprof"
@@ -261,6 +262,11 @@ func main() {
261262
}
262263
}
263264

265+
serverOptions.Stats = &server.Metrics{}
266+
storeOptions.OnRemoval = func() {
267+
atomic.AddInt64(&serverOptions.Stats.Sessions, -1)
268+
}
269+
264270
var err error
265271
store, err = storage.New(&storeOptions)
266272
if err != nil {
@@ -273,8 +279,6 @@ func main() {
273279
_ = serverOptions.Storage.SetID(serverOptions.Token)
274280
}
275281

276-
serverOptions.Stats = &server.Metrics{}
277-
278282
// If root-tld is enabled create a singleton unencrypted record in the store
279283
if serverOptions.RootTLD {
280284
for _, domain := range serverOptions.Domains {

pkg/server/http_server.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,13 @@ func (h *HTTPServer) registerHandler(w http.ResponseWriter, req *http.Request) {
388388
return
389389
}
390390

391-
atomic.AddInt64(&h.options.Stats.Sessions, 1)
392-
393391
if err := h.options.Storage.SetIDPublicKey(r.CorrelationID, r.SecretKey, r.PublicKey); err != nil {
394392
gologger.Warning().Msgf("Could not set id and public key for %s: %s\n", r.CorrelationID, err)
395393
jsonError(w, fmt.Sprintf("could not set id and public key: %s", err), http.StatusBadRequest)
396394
return
397395
}
396+
atomic.AddInt64(&h.options.Stats.Sessions, 1)
397+
atomic.AddInt64(&h.options.Stats.SessionsTotal, 1)
398398
jsonMsg(w, "registration successful", http.StatusOK)
399399
gologger.Debug().Msgf("Registered correlationID %s for key\n", r.CorrelationID)
400400
}
@@ -409,8 +409,6 @@ type DeregisterRequest struct {
409409

410410
// deregisterHandler is a handler for client deregister requests
411411
func (h *HTTPServer) deregisterHandler(w http.ResponseWriter, req *http.Request) {
412-
atomic.AddInt64(&h.options.Stats.Sessions, -1)
413-
414412
r := &DeregisterRequest{}
415413
if err := jsoniter.NewDecoder(req.Body).Decode(r); err != nil {
416414
gologger.Warning().Msgf("Could not decode json body: %s\n", err)

pkg/server/http_server_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
package server
22

33
import (
4+
"bytes"
5+
"crypto/rand"
6+
"crypto/rsa"
47
"crypto/tls"
8+
"crypto/x509"
9+
"encoding/base64"
10+
"encoding/pem"
511
"io"
612
"net/http"
713
"net/http/httptest"
14+
"sync"
15+
"sync/atomic"
816
"testing"
917
"time"
1018

19+
"github.com/google/uuid"
20+
jsoniter "github.com/json-iterator/go"
21+
"github.com/projectdiscovery/interactsh/pkg/storage"
22+
"github.com/rs/xid"
1123
"github.com/stretchr/testify/require"
1224
)
1325

@@ -70,3 +82,72 @@ func TestWriteResponseFromDynamicRequest(t *testing.T) {
7082
require.Equal(t, resp.Header.Get("Test"), "Another", "could not get correct result")
7183
})
7284
}
85+
86+
func TestSessionTotalMetric(t *testing.T) {
87+
stats := &Metrics{}
88+
removed := make(chan struct{})
89+
closeOnce := sync.Once{}
90+
91+
store, err := storage.New(&storage.Options{
92+
EvictionTTL: 5 * time.Minute,
93+
OnRemoval: func() {
94+
atomic.AddInt64(&stats.Sessions, -1)
95+
closeOnce.Do(func() { close(removed) })
96+
},
97+
})
98+
require.NoError(t, err)
99+
defer func() { _ = store.Close() }()
100+
101+
h := &HTTPServer{
102+
options: &Options{
103+
Storage: store,
104+
Stats: stats,
105+
},
106+
}
107+
108+
// Generate a client key pair and registration payload.
109+
priv, err := rsa.GenerateKey(rand.Reader, 2048)
110+
require.NoError(t, err)
111+
pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
112+
require.NoError(t, err)
113+
pubPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes})
114+
pubB64 := base64.StdEncoding.EncodeToString(pubPem)
115+
116+
correlationID := xid.New().String()
117+
secretKey := uuid.New().String()
118+
119+
// --- Register ---
120+
regBody, err := jsoniter.Marshal(&RegisterRequest{
121+
PublicKey: pubB64,
122+
SecretKey: secretKey,
123+
CorrelationID: correlationID,
124+
})
125+
require.NoError(t, err)
126+
req := httptest.NewRequest("POST", "/register", bytes.NewReader(regBody))
127+
w := httptest.NewRecorder()
128+
h.registerHandler(w, req)
129+
require.Equal(t, http.StatusOK, w.Code)
130+
131+
require.Equal(t, int64(1), atomic.LoadInt64(&stats.Sessions), "sessions should be 1 after register")
132+
require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should be 1 after register")
133+
134+
// --- Deregister ---
135+
deregBody, err := jsoniter.Marshal(&DeregisterRequest{
136+
SecretKey: secretKey,
137+
CorrelationID: correlationID,
138+
})
139+
require.NoError(t, err)
140+
req = httptest.NewRequest("POST", "/deregister", bytes.NewReader(deregBody))
141+
w = httptest.NewRecorder()
142+
h.deregisterHandler(w, req)
143+
require.Equal(t, http.StatusOK, w.Code)
144+
145+
select {
146+
case <-removed:
147+
case <-time.After(2 * time.Second):
148+
t.Fatal("timed out waiting for OnRemoval callback")
149+
}
150+
151+
require.Equal(t, int64(0), atomic.LoadInt64(&stats.Sessions), "sessions should be 0 after deregister")
152+
require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should remain 1 after deregister")
153+
}

pkg/server/metrics.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@ import (
99
)
1010

1111
type Metrics struct {
12-
Dns uint64 `json:"dns"`
13-
Ftp uint64 `json:"ftp"`
14-
Http uint64 `json:"http"`
15-
Ldap uint64 `json:"ldap"`
16-
Smb uint64 `json:"smb"`
17-
Smtp uint64 `json:"smtp"`
18-
Sessions int64 `json:"sessions"`
19-
Cache *storage.CacheMetrics `json:"cache"`
20-
Memory *MemoryMetrics `json:"memory"`
21-
Cpu *CpuStats `json:"cpu"`
22-
Network *NetworkStats `json:"network"`
12+
Dns uint64 `json:"dns"`
13+
Ftp uint64 `json:"ftp"`
14+
Http uint64 `json:"http"`
15+
Ldap uint64 `json:"ldap"`
16+
Smb uint64 `json:"smb"`
17+
Smtp uint64 `json:"smtp"`
18+
Sessions int64 `json:"sessions"`
19+
SessionsTotal int64 `json:"sessions_total"`
20+
Cache *storage.CacheMetrics `json:"cache"`
21+
Memory *MemoryMetrics `json:"memory"`
22+
Cpu *CpuStats `json:"cpu"`
23+
Network *NetworkStats `json:"network"`
2324
}
2425

2526
func GetCacheMetrics(options *Options) *storage.CacheMetrics {

pkg/storage/option.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ type Options struct {
1515
MaxSize int
1616
MaxSharedInteractions int
1717
EvictionStrategy EvictionStrategy
18+
// OnRemoval is called for each client session removed from cache
19+
// (deregistration, TTL expiry, size eviction, or cache close).
20+
OnRemoval func()
1821
}
1922

2023
func (options *Options) UseDisk() bool {

pkg/storage/roundtrip_test.go

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ func TestStaleDataCleanupOnReRegistration(t *testing.T) {
403403
_ = priv1
404404
}
405405

406-
// TestCacheEvictionCleansLevelDB verifies the OnCacheRemovalCallback properly
406+
// TestCacheEvictionCleansLevelDB verifies the onCacheRemoval callback properly
407407
// deletes LevelDB entries when cache entries are evicted.
408408
func TestCacheEvictionCleansLevelDB(t *testing.T) {
409409
tmpDir, err := os.MkdirTemp("", "interactsh-eviction-*")
@@ -451,7 +451,74 @@ func TestCacheEvictionCleansLevelDB(t *testing.T) {
451451
// Small delay for async eviction callback
452452
time.Sleep(50 * time.Millisecond)
453453

454-
// LevelDB entry should be cleaned up by OnCacheRemovalCallback
454+
// LevelDB entry should be cleaned up by onCacheRemoval
455455
_, err = db.db.Get([]byte(correlationID), nil)
456456
require.Error(t, err, "LevelDB entry should be deleted after cache eviction")
457457
}
458+
459+
// TestOnRemovalSessionTracking verifies that the OnRemoval callback fires
460+
// exactly once per client session on deregister and TTL eviction, and does
461+
// not fire for non-session entries created via SetID.
462+
func TestOnRemovalSessionTracking(t *testing.T) {
463+
removed := make(chan struct{}, 10)
464+
onRemoval := func() { removed <- struct{}{} }
465+
466+
db, err := New(&Options{
467+
EvictionTTL: 50 * time.Millisecond,
468+
EvictionStrategy: EvictionStrategyFixed,
469+
OnRemoval: onRemoval,
470+
})
471+
require.NoError(t, err)
472+
defer func() { _ = db.Close() }()
473+
474+
waitRemoval := func(msg string) {
475+
t.Helper()
476+
select {
477+
case <-removed:
478+
case <-time.After(2 * time.Second):
479+
t.Fatalf("timed out waiting for OnRemoval: %s", msg)
480+
}
481+
}
482+
483+
// --- Non-session entries (SetID) must not trigger OnRemoval ---
484+
// Invalidate a SetID entry, then register+deregister a real session as a
485+
// FIFO barrier: the cache event channel is ordered, so receiving the
486+
// session's callback proves the SetID invalidation was already processed.
487+
_ = db.SetID("token-entry")
488+
db.cache.Invalidate("token-entry")
489+
490+
secret := uuid.New().String()
491+
cid := xid.New().String()
492+
_, pubKey := generateRSAKeyPair(t)
493+
require.NoError(t, db.SetIDPublicKey(cid, secret, pubKey))
494+
require.NoError(t, db.RemoveID(cid, secret))
495+
waitRemoval("deregister barrier")
496+
select {
497+
case <-removed:
498+
t.Fatal("SetID entry should not trigger OnRemoval")
499+
default:
500+
}
501+
502+
// --- TTL eviction must trigger OnRemoval ---
503+
secret2 := uuid.New().String()
504+
cid2 := xid.New().String()
505+
_, pubKey2 := generateRSAKeyPair(t)
506+
require.NoError(t, db.SetIDPublicKey(cid2, secret2, pubKey2))
507+
508+
// Periodically access the cache to trigger lazy eviction.
509+
stop := make(chan struct{})
510+
defer close(stop)
511+
go func() {
512+
ticker := time.NewTicker(10 * time.Millisecond)
513+
defer ticker.Stop()
514+
for {
515+
select {
516+
case <-stop:
517+
return
518+
case <-ticker.C:
519+
db.cache.GetIfPresent(cid2)
520+
}
521+
}
522+
}()
523+
waitRemoval("TTL eviction")
524+
}

pkg/storage/storagedb.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ func New(options *Options) (*StorageDB, error) {
5050
cacheOptions = append(cacheOptions, cache.WithExpireAfterAccess(options.EvictionTTL))
5151
}
5252
}
53-
if options.UseDisk() {
54-
cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.OnCacheRemovalCallback))
55-
}
53+
cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.onCacheRemoval))
5654
cacheDb := cache.New(cacheOptions...)
5755
storageDB.cache = cacheDb
5856

@@ -77,10 +75,21 @@ func New(options *Options) (*StorageDB, error) {
7775
return storageDB, nil
7876
}
7977

80-
func (s *StorageDB) OnCacheRemovalCallback(key cache.Key, value cache.Value) {
81-
if k, ok := key.(string); ok {
78+
func (s *StorageDB) onCacheRemoval(key cache.Key, value cache.Value) {
79+
k, ok := key.(string)
80+
if !ok {
81+
return
82+
}
83+
if s.Options.UseDisk() && s.db != nil {
8284
_ = s.db.Delete([]byte(k), &opt.WriteOptions{})
8385
}
86+
// Only fire for client sessions (entries with a SecretKey),
87+
// not for token/domain entries created via SetID.
88+
if s.Options.OnRemoval != nil {
89+
if cd, ok := value.(*CorrelationData); ok && cd.SecretKey != "" {
90+
s.Options.OnRemoval()
91+
}
92+
}
8493
}
8594

8695
func (s *StorageDB) GetCacheMetrics() (*CacheMetrics, error) {

0 commit comments

Comments
 (0)