diff --git a/announce_test.go b/announce_test.go new file mode 100644 index 00000000..b822b168 --- /dev/null +++ b/announce_test.go @@ -0,0 +1,505 @@ +package pubsub + +import ( + "bytes" + "context" + "testing" + "time" + + pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/peer" +) + +func TestAnnounceStorage(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-storage" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Host 1 subscribes + _, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces + payload := []byte("test storage") + expiry := time.Now().Add(time.Second * 10) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Verify the message is stored in host 0's message cache announcements + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + count := <-resultChan + if count != 1 { + t.Fatalf("expected 1 announcement stored, got %d", count) + } +} + +func TestAnnounceBasic(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce" + hosts := getDefaultHosts(t, 3) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + // Get topics for all hosts + topics := getTopics(psubs, topic) + + // Subscribe on host 1 and 2 + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub2, err := topics[2].Subscribe() + if err != nil { + t.Fatal(err) + } + + // Wait for mesh to form and subscriptions to propagate + time.Sleep(time.Second * 2) + + // Host 0 announces a message (not subscribed) + payload := 
[]byte("announced message") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Subscribers should receive the message via IWANT + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + msg1, err := sub1.Next(timeoutCtx) + if err != nil { + t.Fatalf("host 1 failed to receive message: %v", err) + } + if !bytes.Equal(msg1.Data, payload) { + t.Fatalf("received incorrect message: got %s, want %s", msg1.Data, payload) + } + + msg2, err := sub2.Next(timeoutCtx) + if err != nil { + t.Fatalf("host 2 failed to receive message: %v", err) + } + if !bytes.Equal(msg2.Data, payload) { + t.Fatalf("received incorrect message: got %s, want %s", msg2.Data, payload) + } +} + +func TestAnnounceWhenSubscribed(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-subscribed" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Both hosts subscribe + sub0, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces while subscribed + payload := []byte("announced while subscribed") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Host 0 should NOT receive its own announcement (marked as seen) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + msg, err := sub0.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg != nil { + t.Fatal("announcer should not receive own announcement when subscribed") + } + t.Fatalf("expected timeout, got error: %v", err) + } + + // Host 1 should receive it + msg1, err := sub1.Next(ctx) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(msg1.Data, payload) { + 
t.Fatalf("received incorrect message: got %s, want %s", msg1.Data, payload) + } +} + +func TestAnnounceDuplicate(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-duplicate" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(func(msg *pb.Message) string { + // use a content addressed ID function + return string(msg.Data) + })) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Host 0 subscribes + _, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + // Host 1 subscribes + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + payload := []byte("duplicate test") + expiry := time.Now().Add(time.Second * 5) + + // First announcement should succeed + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatalf("first announce failed: %v", err) + } + + // Host 1 receives the message + msg1, err := sub1.Next(ctx) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(msg1.Data, payload) { + t.Fatal("received incorrect message") + } + + // Try announcing the exact same payload again - this is a duplicate + expiry = time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatalf("second announce failed: %v", err) + } + + // Host 1 should NOT receive the duplicate message (it should be filtered) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*500) + defer cancel() + msg2, err := sub1.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg2 != nil { + t.Fatal("host 1 should not receive duplicate announcement") + } + t.Fatalf("expected timeout for duplicate message, got error: %v", err) + } +} + +func TestAnnounceExpiry(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-expiry" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Only host 
1 subscribes + _, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Announce with very short expiry + payload := []byte("expires soon") + expiry := time.Now().Add(time.Millisecond * 100) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Wait for expiry plus heartbeat + time.Sleep(time.Millisecond*100 + time.Second*2) + + // Try to access the gossipsub router to verify cleanup + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + // Check that the announcement was cleaned up + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + announcementCount := <-resultChan + if announcementCount != 0 { + t.Fatalf("expected 0 announcements after expiry, got %d", announcementCount) + } +} + +func TestAnnounceNoSubscribers(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-no-subs" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // No one subscribes + time.Sleep(time.Millisecond * 500) + + // Announce should succeed even without subscribers (it's a no-op) + payload := []byte("no subscribers") + expiry := time.Now().Add(time.Second * 5) + err := topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Since no one is subscribed, the message is not stored and no IHAVE is sent + _, err = topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Now announce another message - this one should be received + payload2 := []byte("with subscriber") + expiry2 := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload2, expiry2) + if err != nil { + t.Fatal(err) + } 
+ + // Verify the announcement was stored + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + count := <-resultChan + if count != 1 { + t.Fatalf("expected 1 announcements stored, got %d", count) + } +} + +func TestAnnounceMultipleMessages(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-multiple" + hosts := getDefaultHosts(t, 3) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // All hosts subscribe + subs := make([]*Subscription, 3) + for i := range 3 { + sub, err := topics[i].Subscribe() + if err != nil { + t.Fatal(err) + } + subs[i] = sub + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces multiple messages + numMessages := 5 + payloads := make([][]byte, numMessages) + expiry := time.Now().Add(time.Second * 10) + + for i := range numMessages { + payloads[i] = []byte("message " + string(rune('0'+i))) + err := topics[0].Announce(ctx, payloads[i], expiry) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) + } + + // Host 1 and 2 should receive all messages + for hostIdx := 1; hostIdx < 3; hostIdx++ { + receivedCount := 0 + for receivedCount < numMessages { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second*2) + _, err := subs[hostIdx].Next(timeoutCtx) + cancel() + if err != nil { + t.Fatalf("host %d: failed to receive message %d: %v", hostIdx, receivedCount, err) + } + receivedCount++ + } + } +} + +func TestAnnounceWithClosedTopic(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-closed" + hosts := getDefaultHosts(t, 1) + psubs := getGossipsubs(ctx, hosts) + + topics := getTopics(psubs, topic) + + // Close the topic + err := topics[0].Close() 
+ if err != nil { + t.Fatal(err) + } + + // Announce should fail with ErrTopicClosed + payload := []byte("should fail") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != ErrTopicClosed { + t.Fatalf("expected ErrTopicClosed, got %v", err) + } +} + +func TestAnnounceWithFloodsub(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-floodsub" + hosts := getDefaultHosts(t, 1) + + // Create a floodsub instance instead of gossipsub + psubs := getPubsubs(ctx, hosts) // This creates floodsub + + topics := getTopics(psubs, topic) + + // Announce should fail with non-GossipSub router + payload := []byte("floodsub test") + expiry := time.Now().Add(time.Second * 5) + err := topics[0].Announce(ctx, payload, expiry) + if err == nil { + t.Fatal("expected error with floodsub router, got nil") + } +} + +func TestAnnounceGossipThreshold(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-threshold" + hosts := getDefaultHosts(t, 3) + + // Setup peer scoring with gossip threshold + psubs := getGossipsubs(ctx, hosts, + WithPeerScore( + &PeerScoreParams{ + AppSpecificScore: func(p peer.ID) float64 { + // Give host 2 a very low score + if p == hosts[2].ID() { + return -1000 + } + return 0 + }, + AppSpecificWeight: 1.0, + DecayInterval: time.Second, + DecayToZero: 0.01, + }, + &PeerScoreThresholds{ + GossipThreshold: -500, + PublishThreshold: -1000, + GraylistThreshold: -2000, + }, + ), + ) + + connectAll(t, hosts) + topics := getTopics(psubs, topic) + + // All hosts subscribe + _, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + _, err = topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub2, err := topics[2].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second * 1) + + // Host 0 announces + payload := []byte("threshold test") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + 
t.Fatal(err) + } + + // Host 2 with low score should not receive IHAVE + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*500) + defer cancel() + msg, err := sub2.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg != nil { + t.Fatal("host with low score should not receive announcement") + } + t.Fatalf("expected timeout for low-score peer, got error: %v", err) + } +} diff --git a/gossipsub.go b/gossipsub.go index c492ded9..498ba442 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -62,6 +62,7 @@ var ( GossipSubHeartbeatInitialDelay = 100 * time.Millisecond GossipSubHeartbeatInterval = 1 * time.Second GossipSubFanoutTTL = 60 * time.Second + GossipSubAnnouncementMaxTTL = 60 * time.Second GossipSubPrunePeers = 16 GossipSubPruneBackoff = time.Minute GossipSubUnsubscribeBackoff = 10 * time.Second @@ -168,6 +169,11 @@ type GossipSubParams struct { // we'll delete the fanout map for that topic. FanoutTTL time.Duration + // AnnouncementMaxTTL is the maximum possible time-to-live for a message announced + // via Announce. This is used to size internal data structures. Deadlines passed to + // Announce exceeding this value will be clamped, and a warning will be logged. + AnnouncementMaxTTL time.Duration + // PrunePeers controls the number of peers to include in prune Peer eXchange. // When we prune a peer that's eligible for PX (has a good score, etc), we will try to // send them signed peer records for up to PrunePeers other peers that we @@ -292,6 +298,7 @@ func NewGossipSubWithRouter(ctx context.Context, h host.Host, rt PubSubRouter, o // DefaultGossipSubRouter returns a new GossipSubRouter with default parameters. 
func DefaultGossipSubRouter(h host.Host) *GossipSubRouter { params := DefaultGossipSubParams() + mcache := NewMessageCache(params.HistoryGossip, params.HistoryLength, params.HeartbeatInterval, params.AnnouncementMaxTTL) rt := &GossipSubRouter{ peers: make(map[peer.ID]protocol.ID), mesh: make(map[string]map[peer.ID]struct{}), @@ -307,7 +314,7 @@ func DefaultGossipSubRouter(h host.Host) *GossipSubRouter { outbound: make(map[peer.ID]bool), connect: make(chan connectInfo, params.MaxPendingConnections), cab: pstoremem.NewAddrBook(), - mcache: NewMessageCache(params.HistoryGossip, params.HistoryLength), + mcache: mcache, protos: GossipSubDefaultProtocols, feature: GossipSubDefaultFeatures, tagTracer: newTagTracer(h.ConnManager()), @@ -341,6 +348,7 @@ func DefaultGossipSubParams() GossipSubParams { HeartbeatInitialDelay: GossipSubHeartbeatInitialDelay, HeartbeatInterval: GossipSubHeartbeatInterval, FanoutTTL: GossipSubFanoutTTL, + AnnouncementMaxTTL: GossipSubAnnouncementMaxTTL, PrunePeers: GossipSubPrunePeers, PruneBackoff: GossipSubPruneBackoff, UnsubscribeBackoff: GossipSubUnsubscribeBackoff, @@ -569,7 +577,7 @@ func WithGossipSubParams(cfg GossipSubParams) Option { // Overwrite current config and associated variables in the router. 
gs.params = cfg gs.connect = make(chan connectInfo, cfg.MaxPendingConnections) - gs.mcache = NewMessageCache(cfg.HistoryGossip, cfg.HistoryLength) + gs.mcache = NewMessageCache(cfg.HistoryGossip, cfg.HistoryLength, cfg.HeartbeatInterval, cfg.AnnouncementMaxTTL) return nil } @@ -1303,7 +1311,7 @@ func (gs *GossipSubRouter) Publish(msg *Message) { func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] { return func(yield func(peer.ID, *RPC) bool) { - gs.mcache.Put(msg) + gs.mcache.AppendWindow(msg) from := msg.ReceivedFrom topic := msg.GetTopic() @@ -1595,6 +1603,9 @@ func (gs *GossipSubRouter) heartbeat() { // clean up IDONTWANT counters gs.clearIDontWantCounters() + // clean up expired announcements + gs.purgeAnnouncements() + // apply IWANT request penalties gs.applyIwantPenalties() @@ -1832,7 +1843,7 @@ func (gs *GossipSubRouter) heartbeat() { gs.flush() // advance the message history window - gs.mcache.Shift() + gs.mcache.ShiftWindow() } func (gs *GossipSubRouter) clearIHaveCounters() { @@ -1864,6 +1875,10 @@ func (gs *GossipSubRouter) clearIDontWantCounters() { } } +func (gs *GossipSubRouter) purgeAnnouncements() { + gs.mcache.PruneAnns() +} + func (gs *GossipSubRouter) applyIwantPenalties() { for p, count := range gs.gossipTracer.GetBrokenPromises() { gs.logger.Info("peer didn't follow up in IWANT requests; adding penalty", "peer", p, "requestCount", count) @@ -1956,7 +1971,7 @@ func (gs *GossipSubRouter) sendGraftPrune(tograft, toprune map[peer.ID][]string, // emitGossip emits IHAVE gossip advertising items in the message cache window // of this topic. 
func (gs *GossipSubRouter) emitGossip(topic string, exclude map[peer.ID]struct{}) {
-	mids := gs.mcache.GetGossipIDs(topic)
+	mids := gs.mcache.GossipForTopic(topic)
 	if len(mids) == 0 {
 		return
 	}
@@ -2033,6 +2048,55 @@ func (gs *GossipSubRouter) enqueueGossip(p peer.ID, ihave *pb.ControlIHave) {
 	gs.gossip[p] = gossip
 }
 
+// announceMessage sends an IHAVE for msg to the eligible peers of topic and, if
+// at least one IHAVE was queued, tracks msg in the message cache until expiry
+// so it can be served on IWANT. It is a no-op when the topic has no entry in
+// gs.p.topics or when expiry is not in the future.
+func (gs *GossipSubRouter) announceMessage(topic string, msg *Message, expiry time.Time) {
+	// Get all peers in topic
+	tmap, ok := gs.p.topics[topic]
+	if !ok {
+		return
+	}
+
+	// TrackAnn ignores deadlines that have already passed; advertising such a
+	// message would promise peers something that cannot be served on IWANT.
+	if !expiry.After(time.Now()) {
+		return
+	}
+
+	msgID := gs.p.idGen.ID(msg)
+
+	// Send IHAVE to every topic peer that supports the mesh feature and whose
+	// score is at or above the gossip threshold.
+	// NOTE(review): direct peers are not skipped here; confirm whether this
+	// should mirror the exclusions applied by emitGossip.
+	var gossipQueued bool
+	for p := range tmap {
+		if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) {
+			continue
+		}
+		if gs.score.Score(p) < gs.gossipThreshold {
+			continue
+		}
+		gs.enqueueGossip(p, &pb.ControlIHave{
+			TopicID:    &topic,
+			MessageIDs: []string{msgID},
+		})
+		gossipQueued = true
+	}
+
+	if !gossipQueued {
+		return
+	}
+
+	// Track announcement in message cache for IWANT retrieval
+	gs.mcache.TrackAnn(msg, expiry)
+
+	// Flush gossip immediately
+	gs.flush()
+}
+
 func (gs *GossipSubRouter) piggybackGossip(p peer.ID, out *RPC, ihave []*pb.ControlIHave) {
 	ctl := out.GetControl()
 	if ctl == nil {
diff --git a/mcache.go b/mcache.go
index e4e82d90..12d9f3ba 100644
--- a/mcache.go
+++ b/mcache.go
@@ -2,69 +2,108 @@ package pubsub
 
 import (
 	"fmt"
+	"time"
 
 	"github.com/libp2p/go-libp2p/core/peer"
 )
 
+type historyEntry struct {
+	mid   string
+	topic string
+}
+
+type messageRef struct {
+	*Message
+	refs int
+}
+
+type MessageCache struct {
+	msgID func(*Message) string
+
+	// All messages unified storage, indexed by message ID
+	// Messages can be in window, announcement wheel, or both
+	msgs map[string]*messageRef
+
+	// Sliding window for all messages
+	history   [][]historyEntry
+	gossipLen int
+
+	// Time wheel for announcements with expiry-based cleanup
+	// Behaves like a circular buffer of time buckets containing message IDs
+	// Actual messages are stored in the unified storage `msgs`
+	annWheel     [][]string
+	annWheelPos  int
+	annWheelTick time.Duration
+
+	// Per-peer transmission counters
+	peertx map[string]map[peer.ID]int
+}
+
 // NewMessageCache creates a sliding window cache that remembers messages for as
-// long as `history` slots.
+// long as `historyLen` slots.
 //
-// When queried for messages to advertise, the cache only returns messages in
-// the last `gossip` slots.
+// When queried for messages to advertise via gossip, the cache only returns messages
+// in the last `gossipLen` slots.
 //
-// The `gossip` parameter must be smaller or equal to `history`, or this
+// The `gossipLen` parameter must be smaller or equal to `historyLen`, or this
 // function will panic.
 //
-// The slack between `gossip` and `history` accounts for the reaction time
+// The slack between `gossipLen` and `historyLen` accounts for the reaction time
 // between when a message is advertised via IHAVE gossip, and the peer pulls it
 // via an IWANT command.
-func NewMessageCache(gossip, history int) *MessageCache { - if gossip > history { +func NewMessageCache(gossipLen, historyLen int, heartbeatInterval, maxTTL time.Duration) *MessageCache { + if gossipLen > historyLen { err := fmt.Errorf("invalid parameters for message cache; gossip slots (%d) cannot be larger than history slots (%d)", - gossip, history) + gossipLen, historyLen) panic(err) } + + wheelLen := ceilDivDuration(maxTTL, heartbeatInterval) + wheel := make([][]string, wheelLen) + return &MessageCache{ - msgs: make(map[string]*Message), - peertx: make(map[string]map[peer.ID]int), - history: make([][]CacheEntry, history), - gossip: gossip, + msgs: make(map[string]*messageRef), + peertx: make(map[string]map[peer.ID]int), + history: make([][]historyEntry, historyLen), + gossipLen: gossipLen, + annWheel: wheel, + annWheelPos: 0, + annWheelTick: heartbeatInterval, msgID: func(msg *Message) string { return DefaultMsgIdFn(msg.Message) }, } } -type MessageCache struct { - msgs map[string]*Message - peertx map[string]map[peer.ID]int - history [][]CacheEntry - gossip int - msgID func(*Message) string -} - func (mc *MessageCache) SetMsgIdFn(msgID func(*Message) string) { mc.msgID = msgID } -type CacheEntry struct { - mid string - topic string -} - -func (mc *MessageCache) Put(msg *Message) { - mid := mc.msgID(msg) - mc.msgs[mid] = msg - mc.history[0] = append(mc.history[0], CacheEntry{mid: mid, topic: msg.GetTopic()}) +// AppendWindow adds a message to the sliding window cache. +// The message will be retained for the duration of the window. +// If the message already exists in the cache, its reference count is incremented. +func (mc *MessageCache) AppendWindow(msg *Message) { + mid := mc.upsertMessage(msg) + mc.history[0] = append(mc.history[0], historyEntry{mid: mid, topic: msg.GetTopic()}) } +// Get retrieves the message for the given message ID without modifying +// any transmission counts. 
+// It returns the message and a boolean indicating whether the message was found in the cache. func (mc *MessageCache) Get(mid string) (*Message, bool) { - m, ok := mc.msgs[mid] - return m, ok + ref, ok := mc.msgs[mid] + if !ok { + return nil, false + } + return ref.Message, true } +// GetForPeer retrieves the message for the given message ID and increments +// the transmission count for the specified peer. +// It returns the message, the updated transmission count, and a boolean indicating +// whether the message was found in the cache. func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*Message, int, bool) { - m, ok := mc.msgs[mid] + ref, ok := mc.msgs[mid] if !ok { return nil, 0, false } @@ -76,12 +115,13 @@ func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*Message, int, bool) } tx[p]++ - return m, tx[p], true + return ref.Message, tx[p], true } -func (mc *MessageCache) GetGossipIDs(topic string) []string { +// GossipForTopic returns the message IDs in the gossip window for the given topic. +func (mc *MessageCache) GossipForTopic(topic string) []string { var mids []string - for _, entries := range mc.history[:mc.gossip] { + for _, entries := range mc.history[:mc.gossipLen] { for _, entry := range entries { if entry.topic == topic { mids = append(mids, entry.mid) @@ -91,10 +131,13 @@ func (mc *MessageCache) GetGossipIDs(topic string) []string { return mids } -func (mc *MessageCache) Shift() { +// ShiftWindow advances the sliding window by one slot. +// Messages that fall out of the window have their reference counts decremented +// and are removed from the cache if they are no longer referenced. 
+func (mc *MessageCache) ShiftWindow() {
 	last := mc.history[len(mc.history)-1]
 	for _, entry := range last {
-		delete(mc.msgs, entry.mid)
+		mc.tryDropMessage(entry.mid)
 		delete(mc.peertx, entry.mid)
 	}
 	for i := len(mc.history) - 2; i >= 0; i-- {
@@ -102,3 +145,101 @@ func (mc *MessageCache) Shift() {
 	}
 	mc.history[0] = nil
 }
+
+// TrackAnn adds a message to the announcement cache with time-based expiry.
+// Unlike AppendWindow, these messages are not part of the sliding window and expire at a specific time.
+//
+// Expiries that are not in the future are ignored, and so is every call made
+// while the wheel has no buckets. A TTL that reaches past the wheel is clamped
+// to the farthest bucket instead of wrapping around, which would evict the
+// message on one of the next ticks.
+func (mc *MessageCache) TrackAnn(msg *Message, expiry time.Time) {
+	if len(mc.annWheel) == 0 {
+		return
+	}
+
+	ttl := time.Until(expiry)
+	if ttl <= 0 {
+		return
+	}
+
+	// PruneAnns clears the bucket at annWheelPos first, so the farthest usable
+	// offset is len(annWheel)-1.
+	offset := min(ceilDivDuration(ttl, mc.annWheelTick), len(mc.annWheel)-1)
+	bucket := (mc.annWheelPos + offset) % len(mc.annWheel)
+
+	// Insert the message into the storage and the wheel
+	mid := mc.upsertMessage(msg)
+	mc.annWheel[bucket] = append(mc.annWheel[bucket], mid)
+}
+
+// PruneAnns removes expired announcements from the cache.
+// This should be called periodically (e.g., during heartbeat).
+// Advances the time wheel by one tick and cleans up the current bucket.
+// It does nothing when the wheel has no buckets.
+func (mc *MessageCache) PruneAnns() {
+	if len(mc.annWheel) == 0 {
+		return
+	}
+
+	bucket := mc.annWheel[mc.annWheelPos]
+
+	// Drop all messages in the current bucket
+	for _, mid := range bucket {
+		mc.tryDropMessage(mid)
+		delete(mc.peertx, mid)
+	}
+
+	// Clear the current bucket and advance the wheel position
+	mc.annWheel[mc.annWheelPos] = mc.annWheel[mc.annWheelPos][:0]
+	mc.annWheelPos = (mc.annWheelPos + 1) % len(mc.annWheel)
+}
+
+// tryDropMessage decrements the reference count of the message with the given ID.
+// If the reference count reaches zero, the message is removed from the cache.
+// Message IDs that are not in the cache are ignored.
+func (mc *MessageCache) tryDropMessage(mid string) {
+	ref, ok := mc.msgs[mid]
+	if !ok {
+		return
+	}
+	if ref.refs--; ref.refs == 0 {
+		delete(mc.msgs, mid)
+	}
+}
+
+// upsertMessage stores msg under its message ID if it is not cached yet,
+// increments the entry's reference count, and returns the message ID.
+func (mc *MessageCache) upsertMessage(msg *Message) string {
+	mid := mc.msgID(msg)
+	ref, exists := mc.msgs[mid]
+	if !exists {
+		ref = &messageRef{Message: msg}
+		mc.msgs[mid] = ref
+	}
+	ref.refs++
+	return mid
+}
+
+// ceilDivDuration performs ceiling division of two time.Duration values.
+// It panics if b is not positive and returns 0 if a is not positive.
+func ceilDivDuration(a, b time.Duration) int {
+	switch {
+	case b <= 0:
+		panic("b must be > 0")
+	case a <= 0:
+		return 0
+	default:
+		// Divide in Duration (int64) space so that the operands are not
+		// truncated on platforms where int is 32 bits wide.
+		q := a / b
+		if a%b != 0 {
+			q++
+		}
+		const maxInt = int(^uint(0) >> 1)
+		if int64(q) > int64(maxInt) {
+			return maxInt
+		}
+		return int(q)
+	}
+}
diff --git a/mcache_test.go b/mcache_test.go
index 93bcfdc6..41ef6e64 100644
--- a/mcache_test.go
+++ b/mcache_test.go
@@ -4,12 +4,13 @@ import (
 	"encoding/binary"
 	"fmt"
 	"testing"
+	"time"
 
 	pb "github.com/libp2p/go-libp2p-pubsub/pb"
 )
 
 func TestMessageCache(t *testing.T) {
-	mcache := NewMessageCache(3, 5)
+	mcache := NewMessageCache(3, 5, time.Second, 60*time.Second) // 3 gossip, 5 history, 1s heartbeat, 60s max TTL
 	msgID := DefaultMsgIdFn
 
 	msgs := make([]*pb.Message, 60)
@@ -17,11 +18,11 @@ func TestMessageCache(t *testing.T) {
 		msgs[i] = makeTestMessage(i)
 	}
 
-	for i := 0; i < 10; i++ {
-		mcache.Put(&Message{Message: msgs[i]})
+	for i := range 10 {
+		mcache.AppendWindow(&Message{Message: msgs[i]})
 	}
 
-	for i := 0; i < 10; i++ {
+	for i := range 10 {
 		mid := msgID(msgs[i])
 		m, ok := mcache.Get(mid)
 		if !ok {
@@ -33,21 +34,21 @@ func TestMessageCache(t *testing.T) {
 		}
 	}
 
-	gids := mcache.GetGossipIDs("test")
+	gids := mcache.GossipForTopic("test")
 	if len(gids) != 10 {
 		t.Fatalf("Expected 10 gossip IDs; got %d", len(gids))
 	}
 
-	for i := 0; i < 10; i++ {
+	for i := range 10 {
 		mid := msgID(msgs[i])
 		if mid != gids[i] {
 			t.Fatalf("GossipID mismatch for message %d", i)
 		}
 	}
 
-	mcache.Shift()
+	mcache.ShiftWindow()
 	for i := 10; i < 20; i++ {
-		mcache.Put(&Message{Message: msgs[i]})
+		mcache.AppendWindow(&Message{Message: msgs[i]})
 	}
 
 	for i := 0; i <
20; i++ { @@ -62,12 +63,12 @@ func TestMessageCache(t *testing.T) { } } - gids = mcache.GetGossipIDs("test") + gids = mcache.GossipForTopic("test") if len(gids) != 20 { t.Fatalf("Expected 20 gossip IDs; got %d", len(gids)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) if mid != gids[10+i] { t.Fatalf("GossipID mismatch for message %d", i) @@ -81,31 +82,31 @@ func TestMessageCache(t *testing.T) { } } - mcache.Shift() + mcache.ShiftWindow() for i := 20; i < 30; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 30; i < 40; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 40; i < 50; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 50; i < 60; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } if len(mcache.msgs) != 50 { t.Fatalf("Expected 50 messages in the cache; got %d", len(mcache.msgs)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) _, ok := mcache.Get(mid) if ok { @@ -125,12 +126,12 @@ func TestMessageCache(t *testing.T) { } } - gids = mcache.GetGossipIDs("test") + gids = mcache.GossipForTopic("test") if len(gids) != 30 { t.Fatalf("Expected 30 gossip IDs; got %d", len(gids)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[50+i]) if mid != gids[i] { t.Fatalf("GossipID mismatch for message %d", i) @@ -165,3 +166,97 @@ func makeTestMessage(n int) *pb.Message { Seqno: seqno, } } + +func TestAnnouncementTimeWheel(t *testing.T) { + // Create cache with 60 buckets for announcements (simulating 60 heartbeat intervals) + mcache := NewMessageCache(3, 5, time.Second, 60*time.Second) + msgID := DefaultMsgIdFn + + // Test basic insertion + msg1 := 
makeTestMessage(1) + expiry1 := time.Now().Add(5 * time.Second) + mcache.TrackAnn(&Message{Message: msg1}, expiry1) + + mid1 := msgID(msg1) + + // Verify message is in cache (announcements are stored in msgs) + if _, ok := mcache.Get(mid1); !ok { + t.Fatal("Message not in announcement cache") + } + + // Verify message can be retrieved + m, _, ok := mcache.GetForPeer(mid1, "peer1") + if !ok { + t.Fatal("Failed to retrieve announced message") + } + if m.Message != msg1 { + t.Fatal("Retrieved message doesn't match") + } + + // Test multiple messages with different expiries + msg2 := makeTestMessage(2) + msg3 := makeTestMessage(3) + expiry2 := time.Now().Add(10 * time.Second) + expiry3 := time.Now().Add(15 * time.Second) + + mcache.TrackAnn(&Message{Message: msg2}, expiry2) + mcache.TrackAnn(&Message{Message: msg3}, expiry3) + + mid2 := msgID(msg2) + mid3 := msgID(msg3) + + // Verify all messages are in cache + if _, ok := mcache.Get(mid1); !ok { + t.Fatal("Message 1 should be in cache") + } + if _, ok := mcache.Get(mid2); !ok { + t.Fatal("Message 2 should be in cache") + } + if _, ok := mcache.Get(mid3); !ok { + t.Fatal("Message 3 should be in cache") + } + + // Test wheel advancement (cleanup) + // Advance 6 ticks (6 seconds) - msg1 should be cleaned up + for i := 0; i < 6; i++ { + mcache.PruneAnns() + } + + // msg1 should be gone + if _, ok := mcache.Get(mid1); ok { + t.Fatal("Message 1 should have been cleaned up") + } + + // msg2 and msg3 should still exist + if _, ok := mcache.Get(mid2); !ok { + t.Fatal("Message 2 should still exist") + } + if _, ok := mcache.Get(mid3); !ok { + t.Fatal("Message 3 should still exist") + } + + // Test expired message insertion (shouldn't be added) + msg4 := makeTestMessage(4) + expiry4 := time.Now().Add(-1 * time.Second) // Already expired + mcache.TrackAnn(&Message{Message: msg4}, expiry4) + + mid4 := msgID(msg4) + if _, ok := mcache.Get(mid4); ok { + t.Fatal("Expired message should not have been added") + } + + // Test wraparound 
(TTL > wheel size)
+	msg5 := makeTestMessage(5)
+	expiry5 := time.Now().Add(70 * time.Second) // Exceeds 60s max
+	mcache.TrackAnn(&Message{Message: msg5}, expiry5)
+
+	mid5 := msgID(msg5)
+	if _, ok := mcache.Get(mid5); !ok {
+		t.Fatal("Long TTL message should still be added (clamped to last bucket)")
+	}
+
+	// Verify we still have at least 3 messages in cache (msg2, msg3, msg5)
+	if len(mcache.msgs) < 3 {
+		t.Fatalf("Expected at least 3 messages in cache, got %d", len(mcache.msgs))
+	}
+}
diff --git a/topic.go b/topic.go
index dd094eae..ee4ca117 100644
--- a/topic.go
+++ b/topic.go
@@ -261,6 +261,63 @@ func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte
 	return nil
 }
 
+// Announce sends IHAVE gossip for a message to all peers subscribed to the topic
+// without publishing it through the mesh. The message is stored for IWANT retrieval
+// until the expiry time. Works even if we're not subscribed to the topic - in that
+// case, IHAVE is sent to all connected peers who are subscribed. If we are subscribed,
+// the message is marked as seen to prevent duplicate processing.
+func (t *Topic) Announce(ctx context.Context, data []byte, expiry time.Time, opts ...PubOpt) error {
+	// Hold the read lock only while inspecting t.closed: validate acquires it
+	// again, and sync.RWMutex prohibits recursive read locking because a
+	// concurrent Lock call queued in between would deadlock both goroutines.
+	t.mux.RLock()
+	closed := t.closed
+	t.mux.RUnlock()
+	if closed {
+		return ErrTopicClosed
+	}
+
+	// Validate and construct message (reuse existing validation logic)
+	msg, err := t.validate(ctx, data, opts...)
+	if err != nil {
+		if errors.Is(err, dupeErr{}) {
+			// If it was a duplicate, we return nil to indicate success.
+			// Semantically the message was published by us or someone else.
+			return nil
+		}
+		return err
+	}
+
+	// Get GossipSubRouter
+	gs, ok := t.p.rt.(*GossipSubRouter)
+	if !ok {
+		return fmt.Errorf("announce only works with GossipSub router")
+	}
+
+	// Execute in pubsub event loop
+	done := make(chan struct{})
+	select {
+	case t.p.eval <- func() {
+		gs.announceMessage(t.topic, msg, expiry)
+		close(done)
+	}:
+	case <-t.p.ctx.Done():
+		return t.p.ctx.Err()
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	// Wait for completion
+	select {
+	case <-done:
+		return nil
+	case <-t.p.ctx.Done():
+		return t.p.ctx.Err()
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
 func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Message, error) {
 	t.mux.RLock()
 	defer t.mux.RUnlock()