1- import { ClusterAdapter , MessageType } from "socket.io-adapter" ;
1+ import { BroadcastFlags , ClusterAdapter , MessageType } from "socket.io-adapter" ;
22import type {
33 ClusterAdapterOptions ,
44 ClusterMessage ,
@@ -15,7 +15,6 @@ import {
1515 SET ,
1616 SUBSCRIBE ,
1717 XADD ,
18- XRANGE ,
1918 XREAD ,
2019 hashCode ,
2120 duplicateClient ,
@@ -27,8 +26,6 @@ import {
2726
2827const debug = debugModule ( "socket.io-redis-streams-adapter" ) ;
2928
30- const RESTORE_SESSION_MAX_XRANGE_CALLS = 100 ;
31-
3229export interface RedisStreamsAdapterOptions {
3330 /**
3431 * The name of the Redis stream (or the prefix used when using multiple streams).
@@ -271,11 +268,19 @@ function isEphemeral(message: ClusterMessage) {
271268 ) ;
272269}
273270
271+ const RESTORABLE_MESSAGE_TYPES = new Set < MessageType > ( [
272+ MessageType . BROADCAST ,
273+ MessageType . SOCKETS_JOIN ,
274+ MessageType . SOCKETS_LEAVE ,
275+ MessageType . DISCONNECT_SOCKETS ,
276+ ] ) ;
277+
274278class RedisStreamsAdapter extends ClusterAdapter {
275279 readonly #redisClient: any ;
276280 readonly #opts: Required < RedisStreamsAdapterOptions > ;
277281 readonly #streamName: string ;
278282 readonly #publicChannel: string ;
283+ readonly #messageBuffer?: MessageBuffer ;
279284
280285 constructor (
281286 nsp : any ,
@@ -306,6 +311,12 @@ class RedisStreamsAdapter extends ClusterAdapter {
306311 }
307312 ) ;
308313 } ) ;
314+
315+ if ( nsp . server . opts . connectionStateRecovery !== undefined ) {
316+ const maxMessages =
317+ nsp . server . opts . connectionStateRecovery . maxMessages || 10_000 ;
318+ this . #messageBuffer = new MessageBuffer ( maxMessages ) ;
319+ }
309320 }
310321
311322 override doPublish ( message : ClusterMessage ) {
@@ -387,6 +398,13 @@ class RedisStreamsAdapter extends ClusterAdapter {
387398 return debug ( "invalid format: %s" , e . message ) ;
388399 }
389400
401+ if (
402+ this . nsp . server . opts . connectionStateRecovery !== undefined &&
403+ RESTORABLE_MESSAGE_TYPES . has ( message . type )
404+ ) {
405+ this . #messageBuffer. add ( message , offset ) ;
406+ }
407+
390408 this . onMessage ( message , offset ) ;
391409 }
392410
@@ -446,13 +464,8 @@ class RedisStreamsAdapter extends ClusterAdapter {
446464
447465 const sessionKey = this . #opts. sessionKeyPrefix + pid ;
448466
449- const results = await Promise . all ( [
450- GETDEL ( this . #redisClient, sessionKey ) ,
451- XRANGE ( this . #redisClient, this . #streamName, offset , offset ) ,
452- ] ) ;
453-
454- const rawSession = results [ 0 ] [ 0 ] ;
455- const offsetExists = results [ 1 ] [ 0 ] ;
467+ const [ rawSession ] = await GETDEL ( this . #redisClient, sessionKey ) ;
468+ const offsetExists = this . #messageBuffer. hasOffset ( offset ) ;
456469
457470 if ( ! rawSession || ! offsetExists ) {
458471 return Promise . reject ( "session or offset not found" ) ;
@@ -464,33 +477,30 @@ class RedisStreamsAdapter extends ClusterAdapter {
464477
465478 session . missedPackets = [ ] ;
466479
467- // FIXME we need to add an arbitrary limit here, because if entries are added faster than what we can consume, then
468- // we will loop endlessly. But if we stop before reaching the end of the stream, we might lose messages.
469- for ( let i = 0 ; i < RESTORE_SESSION_MAX_XRANGE_CALLS ; i ++ ) {
470- const entries = await XRANGE (
471- this . #redisClient,
472- this . #streamName,
473- RedisStreamsAdapter . nextOffset ( offset ) ,
474- "+"
475- ) ;
476-
477- if ( entries . length === 0 ) {
478- break ;
479- }
480-
481- for ( const entry of entries ) {
482- if ( entry . message . nsp === this . nsp . name && entry . message . type === "3" ) {
483- const message = RedisStreamsAdapter . decode ( entry . message ) as {
484- data : any ;
485- } ;
486- const { packet, opts } = message . data ;
487-
488- if ( shouldIncludePacket ( session . rooms , opts ) ) {
489- packet . data . push ( entry . id ) ;
490- session . missedPackets . push ( packet . data ) ;
491- }
480+ for ( const entry of this . #messageBuffer. getFromOffset ( offset ) ) {
481+ const { message } = entry ;
482+ const messageOffset = entry . offset ;
483+ if ( isSocketImpacted ( session . rooms , message . data . opts ) ) {
484+ switch ( message . type ) {
485+ case MessageType . BROADCAST :
486+ const packetData = message . data . packet . data ;
487+ packetData . push ( messageOffset ) ;
488+ session . missedPackets . push ( packetData ) ;
489+ break ;
490+ case MessageType . SOCKETS_JOIN :
491+ session . rooms . push ( ...message . data . rooms ) ;
492+ break ;
493+ case MessageType . SOCKETS_LEAVE :
494+ for ( const room of message . data . rooms ) {
495+ const i = session . rooms . indexOf ( room ) ;
496+ if ( i !== - 1 ) {
497+ session . rooms . splice ( i , 1 ) ;
498+ }
499+ }
500+ break ;
501+ case MessageType . DISCONNECT_SOCKETS :
502+ return Promise . reject ( "session was manually disconnected" ) ;
492503 }
493- offset = entry . id ;
494504 }
495505 }
496506
@@ -510,7 +520,13 @@ class RedisStreamsAdapter extends ClusterAdapter {
510520 }
511521}
512522
513- function shouldIncludePacket ( sessionRooms , opts ) {
523+ function isSocketImpacted (
524+ sessionRooms : string [ ] ,
525+ opts : {
526+ rooms : string [ ] ;
527+ except : string [ ] ;
528+ }
529+ ) {
514530 const included =
515531 opts . rooms . length === 0 ||
516532 sessionRooms . some ( ( room ) => opts . rooms . indexOf ( room ) !== - 1 ) ;
@@ -519,3 +535,89 @@ function shouldIncludePacket(sessionRooms, opts) {
519535 ) ;
520536 return included && notExcluded ;
521537}
538+
539+ type RestorableMessage =
540+ | {
541+ type : MessageType . BROADCAST ;
542+ data : {
543+ opts : {
544+ rooms : string [ ] ;
545+ except : string [ ] ;
546+ flags : BroadcastFlags ;
547+ } ;
548+ packet : {
549+ data : any [ ] ;
550+ } ;
551+ requestId ?: string ;
552+ } ;
553+ }
554+ | {
555+ type : MessageType . SOCKETS_JOIN | MessageType . SOCKETS_LEAVE ;
556+ data : {
557+ opts : {
558+ rooms : string [ ] ;
559+ except : string [ ] ;
560+ flags : BroadcastFlags ;
561+ } ;
562+ rooms : string [ ] ;
563+ } ;
564+ }
565+ | {
566+ type : MessageType . DISCONNECT_SOCKETS ;
567+ data : {
568+ opts : {
569+ rooms : string [ ] ;
570+ except : string [ ] ;
571+ flags : BroadcastFlags ;
572+ } ;
573+ close ?: boolean ;
574+ } ;
575+ } ;
576+
577+ class MessageBuffer {
578+ readonly #messages: Array < { offset : string ; message : RestorableMessage } > ;
579+ readonly #capacity: number ;
580+ readonly #offsetMap = new Map < string , number > ( ) ;
581+ #writeIndex: number = 0 ;
582+
583+ constructor ( capacity : number ) {
584+ this . #capacity = capacity ;
585+ this . #messages = new Array ( capacity ) ;
586+ }
587+
588+ add ( message : RestorableMessage , offset : string ) {
589+ const oldEntry = this . #messages[ this . #writeIndex] ;
590+
591+ if ( oldEntry ) {
592+ this . #offsetMap. delete ( oldEntry . offset ) ;
593+ }
594+
595+ this . #messages[ this . #writeIndex] = { offset, message } ;
596+ this . #offsetMap. set ( offset , this . #writeIndex) ;
597+ this . #writeIndex = this . #nextIndex( this . #writeIndex) ;
598+ }
599+
600+ #nextIndex( index : number ) {
601+ return ( index + 1 ) % this . #capacity;
602+ }
603+
604+ hasOffset ( offset : string ) {
605+ return this . #offsetMap. has ( offset ) ;
606+ }
607+
608+ * getFromOffset ( offset : string ) {
609+ const offsetIndex = this . #offsetMap. get ( offset ) ;
610+
611+ if ( offsetIndex === undefined ) {
612+ return ;
613+ }
614+
615+ for (
616+ let index = this . #nextIndex( offsetIndex ) ;
617+ index !== this . #writeIndex;
618+ index = this . #nextIndex( index )
619+ ) {
620+ yield this . #messages[ index ] ;
621+ }
622+ }
623+ }
0 commit comments