diff --git a/reflex/.templates/web/utils/react-theme.js b/reflex/.templates/web/utils/react-theme.js index e0f48dc98b3..efb06cd0efb 100644 --- a/reflex/.templates/web/utils/react-theme.js +++ b/reflex/.templates/web/utils/react-theme.js @@ -18,9 +18,18 @@ const ThemeContext = createContext({ export function ThemeProvider({ children, defaultTheme = "system" }) { const [theme, setTheme] = useState(defaultTheme); - const [systemTheme, setSystemTheme] = useState( - defaultTheme !== "system" ? defaultTheme : "light", - ); + + // Detect system preference synchronously during initialization + const getInitialSystemTheme = () => { + if (defaultTheme !== "system") return defaultTheme; + if (typeof window === "undefined") return "light"; + return window.matchMedia("(prefers-color-scheme: dark)").matches + ? "dark" + : "light"; + }; + + const [systemTheme, setSystemTheme] = useState(getInitialSystemTheme); + const [isInitialized, setIsInitialized] = useState(false); const firstRender = useRef(true); @@ -43,6 +52,7 @@ export function ThemeProvider({ children, defaultTheme = "system" }) { // Load saved theme from localStorage const savedTheme = localStorage.getItem("theme") || defaultTheme; setTheme(savedTheme); + setIsInitialized(true); }); const resolvedTheme = useMemo( @@ -68,10 +78,12 @@ export function ThemeProvider({ children, defaultTheme = "system" }) { }; }); - // Save theme to localStorage whenever it changes + // Save theme to localStorage whenever it changes (but not on initial mount) useEffect(() => { - localStorage.setItem("theme", theme); - }, [theme]); + if (isInitialized) { + localStorage.setItem("theme", theme); + } + }, [theme, isInitialized]); useEffect(() => { const root = window.document.documentElement; diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 6bbb0a2c79d..1a5693482d3 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -359,6 +359,8 @@ def create_document_root( Returns: The document root. """ + from reflex.utils.misc import preload_color_theme + existing_meta_types = set() for component in head_components or []: @@ -385,7 +387,11 @@ def create_document_root( Meta.create(name="viewport", content="width=device-width, initial-scale=1") ) + # Add theme preload script as the very first component to prevent FOUC + theme_preload_components = [preload_color_theme()] + head_components = [ + *theme_preload_components, *(head_components or []), *maybe_head_components, *always_head_components, diff --git a/reflex/utils/misc.py b/reflex/utils/misc.py index 421dd9c19a0..396b9037df2 100644 --- a/reflex/utils/misc.py +++ b/reflex/utils/misc.py @@ -90,3 +90,46 @@ def with_cwd_in_syspath(): yield finally: sys.path[:] = orig_sys_path + + +def preload_color_theme(): + """Create a script component that preloads the color theme to prevent FOUC. + + This script runs immediately in the document head before React hydration, + reading the saved theme from localStorage and applying the correct CSS classes + to prevent flash of unstyled content. + + Returns: + Script: A script component to add to App.head_components + """ + from reflex.components.el.elements.scripts import Script + + # Create direct inline script content (like next-themes dangerouslySetInnerHTML) + script_content = """ +// Only run in browser environment, not during SSR +if (typeof document !== 'undefined') { + try { + const theme = localStorage.getItem("theme") || "system"; + const systemPreference = window.matchMedia("(prefers-color-scheme: dark)").matches ? "dark" : "light"; + const resolvedTheme = theme === "system" ? systemPreference : theme; + + console.log("[PRELOAD] Theme applied:", resolvedTheme, "from theme:", theme, "system:", systemPreference); + + // Apply theme immediately - blocks until complete + // Use classList to avoid overwriting other classes + document.documentElement.classList.remove("light", "dark"); + document.documentElement.classList.add(resolvedTheme); + document.documentElement.style.colorScheme = resolvedTheme; + + } catch (e) { + // Fallback to system preference on any error (resolve "system" to actual theme) + const fallbackTheme = window.matchMedia("(prefers-color-scheme: dark)").matches ? "dark" : "light"; + console.log("[PRELOAD] Error, falling back to:", fallbackTheme); + document.documentElement.classList.remove("light", "dark"); + document.documentElement.classList.add(fallbackTheme); + document.documentElement.style.colorScheme = fallbackTheme; + } +} +""" + + return Script.create(script_content) diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 4e4a618c4b8..18324e610a2 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -364,17 +364,17 @@ def test_create_document_root(): assert isinstance(lang, LiteralStringVar) assert lang.equals(Var.create("en")) # No children in head. - assert len(root.children[0].children) == 4 - assert isinstance(root.children[0].children[0], utils.Meta) - char_set = root.children[0].children[0].char_set # pyright: ignore [reportAttributeAccessIssue] + assert len(root.children[0].children) == 5 + assert isinstance(root.children[0].children[1], utils.Meta) + char_set = root.children[0].children[1].char_set # pyright: ignore [reportAttributeAccessIssue] assert isinstance(char_set, LiteralStringVar) assert char_set.equals(Var.create("utf-8")) - assert isinstance(root.children[0].children[1], utils.Meta) - name = root.children[0].children[1].name # pyright: ignore [reportAttributeAccessIssue] + assert isinstance(root.children[0].children[2], utils.Meta) + name = root.children[0].children[2].name # pyright: ignore [reportAttributeAccessIssue] assert isinstance(name, LiteralStringVar) assert name.equals(Var.create("viewport")) - assert isinstance(root.children[0].children[2], document.Meta) - assert isinstance(root.children[0].children[3], document.Links) + assert isinstance(root.children[0].children[3], document.Meta) + assert isinstance(root.children[0].children[4], document.Links) def test_create_document_root_with_scripts(): @@ -389,9 +389,9 @@ def test_create_document_root_with_scripts(): html_custom_attrs={"project": "reflex"}, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 6 + assert len(root.children[0].children) == 7 names = [c.tag for c in root.children[0].children] - assert names == ["Scripts", "Scripts", "meta", "meta", "Meta", "Links"] + assert names == ["script", "Scripts", "Scripts", "meta", "meta", "Meta", "Links"] lang = root.lang # pyright: ignore [reportAttributeAccessIssue] assert isinstance(lang, LiteralStringVar) assert lang.equals(Var.create("rx")) @@ -408,10 +408,10 @@ def test_create_document_root_with_meta_char_set(): head_components=comps, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 4 + assert len(root.children[0].children) == 5 names = [c.tag for c in root.children[0].children] - assert names == ["meta", "meta", "Meta", "Links"] - assert str(root.children[0].children[0].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue] + assert names == ["script", "meta", "meta", "Meta", "Links"] + assert str(root.children[0].children[1].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue] def test_create_document_root_with_meta_viewport(): @@ -424,10 +424,10 @@ def test_create_document_root_with_meta_viewport(): head_components=comps, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 5 + assert len(root.children[0].children) == 6 names = [c.tag for c in root.children[0].children] - assert names == ["meta", "meta", "meta", "Meta", "Links"] - assert str(root.children[0].children[0].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue] - assert str(root.children[0].children[1].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue] - assert str(root.children[0].children[1].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue] - assert str(root.children[0].children[2].char_set) == '"utf-8"' # pyright: ignore [reportAttributeAccessIssue] + assert names == ["script", "meta", "meta", "meta", "Meta", "Links"] + assert str(root.children[0].children[1].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[2].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[2].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[3].char_set) == '"utf-8"' # pyright: ignore [reportAttributeAccessIssue]