@@ -646,6 +646,12 @@ async def rewind_async(
646646 session_id = session_id ,
647647 get_session_config = run_config .get_session_config ,
648648 )
649+ if not rewind_before_invocation_id :
650+ # Guard against matching the synthetic initial-state event that is
651+ # appended by `create_session`; that event has an empty invocation_id by
652+ # design and is not a valid rewind target.
653+ raise ValueError ('rewind_before_invocation_id must be non-empty.' )
654+
649655 rewind_event_index = - 1
650656 for i , event in enumerate (session .events ):
651657 if event .invocation_id == rewind_before_invocation_id :
@@ -686,16 +692,34 @@ async def _compute_state_delta_for_rewind(
686692 self , session : Session , rewind_event_index : int
687693 ) -> dict [str , Any ]:
688694 """Computes the state delta to reverse changes."""
695+ # State at the rewind point is reconstructed entirely from the event
696+ # stream. Session-scoped initial state from `create_session` is captured
697+ # as a synthetic event by `BaseSessionService._record_initial_state_event`,
698+ # so walking events naturally restores initial values even when a later
699+ # event overwrote them.
689700 state_at_rewind_point : dict [str , Any ] = {}
690- for i in range (rewind_event_index ):
691- if session .events [i ].actions .state_delta :
692- for k , v in session .events [i ].actions .state_delta .items ():
693- if k .startswith ('app:' ) or k .startswith ('user:' ):
694- continue
695- if v is None :
696- state_at_rewind_point .pop (k , None )
697- else :
698- state_at_rewind_point [k ] = v
701+ all_event_keys : set [str ] = set ()
702+
703+ for event in session .events [:rewind_event_index ]:
704+ if not event .actions .state_delta :
705+ continue
706+ for k , v in event .actions .state_delta .items ():
707+ if k .startswith ('app:' ) or k .startswith ('user:' ):
708+ continue
709+ all_event_keys .add (k )
710+ if v is None :
711+ state_at_rewind_point .pop (k , None )
712+ else :
713+ state_at_rewind_point [k ] = v
714+
715+ # Collect any other keys touched by events after the rewind point so we
716+ # know which keys were ever event-sourced.
717+ for event in session .events [rewind_event_index :]:
718+ if not event .actions .state_delta :
719+ continue
720+ for k in event .actions .state_delta :
721+ if not k .startswith ('app:' ) and not k .startswith ('user:' ):
722+ all_event_keys .add (k )
699723
700724 current_state = session .state
701725 rewind_state_delta = {}
@@ -706,12 +730,13 @@ async def _compute_state_delta_for_rewind(
706730 rewind_state_delta [key ] = value_at_rewind
707731
708732 # 2. Set keys to None in rewind_state_delta if they are in current_state
709- # but not in state_at_rewind_point. These keys were added after the
710- # rewind point and need to be removed.
733+ # but not in state_at_rewind_point. Only nullify keys that were
734+ # introduced or modified through events; keys set outside the event
735+ # stream are preserved.
711736 for key in current_state :
712737 if key .startswith ('app:' ) or key .startswith ('user:' ):
713738 continue
714- if key not in state_at_rewind_point :
739+ if key not in state_at_rewind_point and key in all_event_keys :
715740 rewind_state_delta [key ] = None
716741
717742 return rewind_state_delta
0 commit comments