diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 22a176e93c3..6bbb0a2c79d 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -359,21 +359,37 @@ def create_document_root( Returns: The document root. """ - head_components = [ - *( - head_components - or [ - # Default meta tags if user does not provide. - Meta.create(char_set="utf-8"), - Meta.create( - name="viewport", content="width=device-width, initial-scale=1" - ), - ] - ), - # Always include the framework meta and link tags. + existing_meta_types = set() + + for component in head_components or []: + if isinstance(component, Meta): + if component.char_set is not None: # pyright: ignore[reportAttributeAccessIssue] + existing_meta_types.add("char_set") + if ( + (name := component.name) is not None # pyright: ignore[reportAttributeAccessIssue] + and name.equals(Var.create("viewport")) + ): + existing_meta_types.add("viewport") + + # Always include the framework meta and link tags. + always_head_components = [ ReactMeta.create(), Links.create(), ] + maybe_head_components = [] + # Only include these if the user has not specified them. + if "char_set" not in existing_meta_types: + maybe_head_components.append(Meta.create(char_set="utf-8")) + if "viewport" not in existing_meta_types: + maybe_head_components.append( + Meta.create(name="viewport", content="width=device-width, initial-scale=1") + ) + + head_components = [ + *(head_components or []), + *maybe_head_components, + *always_head_components, + ] return Html.create( Head.create(*head_components), Body.create( diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 74f62c4297e..4e4a618c4b8 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -376,6 +376,8 @@ def test_create_document_root(): assert isinstance(root.children[0].children[2], document.Meta) assert isinstance(root.children[0].children[3], document.Links) + +def test_create_document_root_with_scripts(): # Test with components. comps = [ utils.Scripts.create(src="foo.js"), @@ -387,11 +389,45 @@ def test_create_document_root(): html_custom_attrs={"project": "reflex"}, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 4 + assert len(root.children[0].children) == 6 names = [c.tag for c in root.children[0].children] - assert names == ["Scripts", "Scripts", "Meta", "Links"] + assert names == ["Scripts", "Scripts", "meta", "meta", "Meta", "Links"] lang = root.lang # pyright: ignore [reportAttributeAccessIssue] assert isinstance(lang, LiteralStringVar) assert lang.equals(Var.create("rx")) assert isinstance(root.custom_attrs, dict) assert root.custom_attrs == {"project": "reflex"} + + +def test_create_document_root_with_meta_char_set(): + # Test with components. + comps = [ + utils.Meta.create(char_set="cp1252"), + ] + root = utils.create_document_root( + head_components=comps, + ) + assert isinstance(root, utils.Html) + assert len(root.children[0].children) == 4 + names = [c.tag for c in root.children[0].children] + assert names == ["meta", "meta", "Meta", "Links"] + assert str(root.children[0].children[0].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue] + + +def test_create_document_root_with_meta_viewport(): + # Test with components. + comps = [ + utils.Meta.create(http_equiv="refresh", content="5"), + utils.Meta.create(name="viewport", content="foo"), + ] + root = utils.create_document_root( + head_components=comps, + ) + assert isinstance(root, utils.Html) + assert len(root.children[0].children) == 5 + names = [c.tag for c in root.children[0].children] + assert names == ["meta", "meta", "meta", "Meta", "Links"] + assert str(root.children[0].children[0].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[1].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[1].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue] + assert str(root.children[0].children[2].char_set) == '"utf-8"' # pyright: ignore [reportAttributeAccessIssue]