Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions reflex/.templates/web/utils/react-theme.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@ const ThemeContext = createContext({

export function ThemeProvider({ children, defaultTheme = "system" }) {
const [theme, setTheme] = useState(defaultTheme);
const [systemTheme, setSystemTheme] = useState(
defaultTheme !== "system" ? defaultTheme : "light",
);

// Detect system preference synchronously during initialization
const getInitialSystemTheme = () => {
if (defaultTheme !== "system") return defaultTheme;
if (typeof window === "undefined") return "light";
return window.matchMedia("(prefers-color-scheme: dark)").matches
? "dark"
: "light";
};

const [systemTheme, setSystemTheme] = useState(getInitialSystemTheme);
const [isInitialized, setIsInitialized] = useState(false);

const firstRender = useRef(true);

Expand All @@ -43,6 +52,7 @@ export function ThemeProvider({ children, defaultTheme = "system" }) {
// Load saved theme from localStorage
const savedTheme = localStorage.getItem("theme") || defaultTheme;
setTheme(savedTheme);
setIsInitialized(true);
});

const resolvedTheme = useMemo(
Expand All @@ -68,10 +78,12 @@ export function ThemeProvider({ children, defaultTheme = "system" }) {
};
});

// Save theme to localStorage whenever it changes
// Save theme to localStorage whenever it changes (but not on initial mount)
useEffect(() => {
localStorage.setItem("theme", theme);
}, [theme]);
if (isInitialized) {
localStorage.setItem("theme", theme);
}
}, [theme, isInitialized]);

useEffect(() => {
const root = window.document.documentElement;
Expand Down
6 changes: 6 additions & 0 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def create_document_root(
Returns:
The document root.
"""
from reflex.utils.misc import preload_color_theme

existing_meta_types = set()

for component in head_components or []:
Expand All @@ -385,7 +387,11 @@ def create_document_root(
Meta.create(name="viewport", content="width=device-width, initial-scale=1")
)

# Add theme preload script as the very first component to prevent FOUC
theme_preload_components = [preload_color_theme()]

head_components = [
*theme_preload_components,
*(head_components or []),
*maybe_head_components,
*always_head_components,
Expand Down
43 changes: 43 additions & 0 deletions reflex/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,46 @@ def with_cwd_in_syspath():
yield
finally:
sys.path[:] = orig_sys_path


def preload_color_theme():
"""Create a script component that preloads the color theme to prevent FOUC.

This script runs immediately in the document head before React hydration,
reading the saved theme from localStorage and applying the correct CSS classes
to prevent flash of unstyled content.

Returns:
Script: A script component to add to App.head_components
"""
from reflex.components.el.elements.scripts import Script

# Create direct inline script content (like next-themes dangerouslySetInnerHTML)
script_content = """
// Only run in browser environment, not during SSR
if (typeof document !== 'undefined') {
try {
const theme = localStorage.getItem("theme") || "system";
const systemPreference = window.matchMedia("(prefers-color-scheme: dark)").matches ? "dark" : "light";
const resolvedTheme = theme === "system" ? systemPreference : theme;

console.log("[PRELOAD] Theme applied:", resolvedTheme, "from theme:", theme, "system:", systemPreference);

// Apply theme immediately - blocks until complete
// Use classList to avoid overwriting other classes
document.documentElement.classList.remove("light", "dark");
document.documentElement.classList.add(resolvedTheme);
document.documentElement.style.colorScheme = resolvedTheme;

} catch (e) {
// Fallback to system preference on any error (resolve "system" to actual theme)
const fallbackTheme = window.matchMedia("(prefers-color-scheme: dark)").matches ? "dark" : "light";
console.log("[PRELOAD] Error, falling back to:", fallbackTheme);
document.documentElement.classList.remove("light", "dark");
document.documentElement.classList.add(fallbackTheme);
document.documentElement.style.colorScheme = fallbackTheme;
}
}
"""

return Script.create(script_content)
36 changes: 18 additions & 18 deletions tests/units/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,17 @@ def test_create_document_root():
assert isinstance(lang, LiteralStringVar)
assert lang.equals(Var.create("en"))
# No children in head.
assert len(root.children[0].children) == 4
assert isinstance(root.children[0].children[0], utils.Meta)
char_set = root.children[0].children[0].char_set # pyright: ignore [reportAttributeAccessIssue]
assert len(root.children[0].children) == 5
assert isinstance(root.children[0].children[1], utils.Meta)
char_set = root.children[0].children[1].char_set # pyright: ignore [reportAttributeAccessIssue]
assert isinstance(char_set, LiteralStringVar)
assert char_set.equals(Var.create("utf-8"))
assert isinstance(root.children[0].children[1], utils.Meta)
name = root.children[0].children[1].name # pyright: ignore [reportAttributeAccessIssue]
assert isinstance(root.children[0].children[2], utils.Meta)
name = root.children[0].children[2].name # pyright: ignore [reportAttributeAccessIssue]
assert isinstance(name, LiteralStringVar)
assert name.equals(Var.create("viewport"))
assert isinstance(root.children[0].children[2], document.Meta)
assert isinstance(root.children[0].children[3], document.Links)
assert isinstance(root.children[0].children[3], document.Meta)
assert isinstance(root.children[0].children[4], document.Links)


def test_create_document_root_with_scripts():
Expand All @@ -389,9 +389,9 @@ def test_create_document_root_with_scripts():
html_custom_attrs={"project": "reflex"},
)
assert isinstance(root, utils.Html)
assert len(root.children[0].children) == 6
assert len(root.children[0].children) == 7
names = [c.tag for c in root.children[0].children]
assert names == ["Scripts", "Scripts", "meta", "meta", "Meta", "Links"]
assert names == ["script", "Scripts", "Scripts", "meta", "meta", "Meta", "Links"]
lang = root.lang # pyright: ignore [reportAttributeAccessIssue]
assert isinstance(lang, LiteralStringVar)
assert lang.equals(Var.create("rx"))
Expand All @@ -408,10 +408,10 @@ def test_create_document_root_with_meta_char_set():
head_components=comps,
)
assert isinstance(root, utils.Html)
assert len(root.children[0].children) == 4
assert len(root.children[0].children) == 5
names = [c.tag for c in root.children[0].children]
assert names == ["meta", "meta", "Meta", "Links"]
assert str(root.children[0].children[0].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue]
assert names == ["script", "meta", "meta", "Meta", "Links"]
assert str(root.children[0].children[1].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue]


def test_create_document_root_with_meta_viewport():
Expand All @@ -424,10 +424,10 @@ def test_create_document_root_with_meta_viewport():
head_components=comps,
)
assert isinstance(root, utils.Html)
assert len(root.children[0].children) == 5
assert len(root.children[0].children) == 6
names = [c.tag for c in root.children[0].children]
assert names == ["meta", "meta", "meta", "Meta", "Links"]
assert str(root.children[0].children[0].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[1].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[1].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[2].char_set) == '"utf-8"' # pyright: ignore [reportAttributeAccessIssue]
assert names == ["script", "meta", "meta", "meta", "Meta", "Links"]
assert str(root.children[0].children[1].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[2].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[2].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue]
assert str(root.children[0].children[3].char_set) == '"utf-8"' # pyright: ignore [reportAttributeAccessIssue]