1+ import { useEffect } from 'react' ;
12import isVisible from './isVisible' ;
23
34type DisabledElement =
@@ -99,12 +100,50 @@ export function triggerFocus(
99100// ======================================================
100101// == Lock Focus ==
101102// ======================================================
103+ let lastFocusElement : HTMLElement | null = null ;
102104let focusElements : HTMLElement [ ] = [ ] ;
103105
104- function onWindowFocus ( e : FocusEvent ) {
105- const lastElement = focusElements [ focusElements . length - 1 ] ;
106+ function getLastElement ( ) {
107+ return focusElements [ focusElements . length - 1 ] ;
108+ }
109+
110+ function hasFocus ( element : HTMLElement ) {
111+ const { activeElement } = document ;
112+ return element === activeElement || element . contains ( activeElement ) ;
113+ }
114+
115+ function syncFocus ( ) {
116+ const lastElement = getLastElement ( ) ;
117+ const { activeElement } = document ;
118+
119+ if ( lastElement && ! hasFocus ( lastElement ) ) {
120+ const focusableList = getFocusNodeList ( lastElement ) ;
121+
122+ const matchElement = focusableList . includes ( lastFocusElement as HTMLElement )
123+ ? lastFocusElement
124+ : focusableList [ 0 ] ;
125+
126+ matchElement ?. focus ( ) ;
127+ } else {
128+ lastFocusElement = activeElement as HTMLElement ;
129+ }
130+ }
106131
107- console . log ( 'lock focus' , e . target , lastElement ) ;
132+ function onWindowKeyDown ( e : KeyboardEvent ) {
133+ if ( e . key === 'Tab' ) {
134+ const { activeElement } = document ;
135+ const lastElement = getLastElement ( ) ;
136+ const focusableList = getFocusNodeList ( lastElement ) ;
137+ const last = focusableList [ focusableList . length - 1 ] ;
138+
139+ if ( e . shiftKey && activeElement === focusableList [ 0 ] ) {
140+ // Tab backward on first focusable element
141+ lastFocusElement = last ;
142+ } else if ( ! e . shiftKey && activeElement === last ) {
143+ // Tab forward on last focusable element
144+ lastFocusElement = focusableList [ 0 ] ;
145+ }
146+ }
108147}
109148
110149/**
@@ -117,12 +156,24 @@ export function lockFocus(element: HTMLElement): VoidFunction {
117156 focusElements . push ( element ) ;
118157
119158 // Just add event since it will de-duplicate
120- window . addEventListener ( 'focusin' , onWindowFocus , true ) ;
159+ window . addEventListener ( 'focusin' , syncFocus ) ;
160+ window . addEventListener ( 'keydown' , onWindowKeyDown , true ) ;
161+ syncFocus ( ) ;
121162
122163 return ( ) => {
164+ lastFocusElement = null ;
123165 focusElements = focusElements . filter ( ele => ele !== element ) ;
124166 if ( focusElements . length === 0 ) {
125- window . removeEventListener ( 'focusin' , onWindowFocus , true ) ;
167+ window . removeEventListener ( 'focusin' , syncFocus ) ;
168+ window . removeEventListener ( 'keydown' , onWindowKeyDown , true ) ;
126169 }
127170 } ;
128171}
172+
173+ export function useFocusLock ( element : HTMLElement | null ) {
174+ useEffect ( ( ) => {
175+ if ( element ) {
176+ return lockFocus ( element ) ;
177+ }
178+ } , [ element ] ) ;
179+ }
0 commit comments