From 83c2dd600ad2f11cd65381a5847b3c95ba9f00c2 Mon Sep 17 00:00:00 2001 From: Cory Dolphin Date: Tue, 6 Jan 2026 14:05:04 -0800 Subject: [PATCH] Use tuples for all AST collection fields (instead of lists) Prepares AST for immutability by using tuples instead of lists for collection fields. This aligns with the JavaScript GraphQL library which uses readonly arrays, and enables future frozen datastructures. --- docs/usage/parser.rst | 18 ++--- src/graphql/language/parser.py | 73 ++++++++++---------- src/graphql/utilities/ast_to_dict.py | 8 ++- src/graphql/utilities/concat_ast.py | 5 +- src/graphql/utilities/separate_operations.py | 4 +- src/graphql/utilities/sort_value_node.py | 21 +++--- tests/language/test_schema_parser.py | 50 +++++++------- tests/utilities/test_ast_from_value.py | 20 +++--- tests/utilities/test_build_ast_schema.py | 2 +- tests/utilities/test_type_info.py | 2 +- 10 files changed, 105 insertions(+), 98 deletions(-) diff --git a/docs/usage/parser.rst b/docs/usage/parser.rst index 049fd7b3..7902adf2 100644 --- a/docs/usage/parser.rst +++ b/docs/usage/parser.rst @@ -35,30 +35,30 @@ This will give the same result as manually creating the AST document:: from graphql.language.ast import * - document = DocumentNode(definitions=[ + document = DocumentNode(definitions=( ObjectTypeDefinitionNode( name=NameNode(value='Query'), - fields=[ + fields=( FieldDefinitionNode( name=NameNode(value='me'), type=NamedTypeNode(name=NameNode(value='User')), - arguments=[], directives=[]) - ], directives=[], interfaces=[]), + arguments=(), directives=()), + ), interfaces=(), directives=()), ObjectTypeDefinitionNode( name=NameNode(value='User'), - fields=[ + fields=( FieldDefinitionNode( name=NameNode(value='id'), type=NamedTypeNode( name=NameNode(value='ID')), - arguments=[], directives=[]), + arguments=(), directives=()), FieldDefinitionNode( name=NameNode(value='name'), type=NamedTypeNode( name=NameNode(value='String')), - arguments=[], directives=[]), - ], directives=[], interfaces=[]), - ]) + arguments=(), directives=()), + ), interfaces=(), directives=()), + )) When parsing with ``no_location=False`` (the default), the AST nodes will also have a diff --git a/src/graphql/language/parser.py b/src/graphql/language/parser.py index 4373cde3..78eb5ccc 100644 --- a/src/graphql/language/parser.py +++ b/src/graphql/language/parser.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import partial -from typing import Callable, List, Mapping, TypeVar, Union, cast +from typing import Callable, Mapping, TypeVar, Union, cast from ..error import GraphQLError, GraphQLSyntaxError from .ast import ( @@ -349,8 +349,8 @@ def parse_operation_definition(self) -> OperationDefinitionNode: return OperationDefinitionNode( operation=OperationType.QUERY, name=None, - variable_definitions=[], - directives=[], + variable_definitions=(), + directives=(), selection_set=self.parse_selection_set(), loc=self.loc(start), ) @@ -373,7 +373,7 @@ def parse_operation_type(self) -> OperationType: except ValueError as error: raise self.unexpected(operation_token) from error - def parse_variable_definitions(self) -> list[VariableDefinitionNode]: + def parse_variable_definitions(self) -> tuple[VariableDefinitionNode, ...]: """VariableDefinitions: (VariableDefinition+)""" return self.optional_many( TokenKind.PAREN_L, self.parse_variable_definition, TokenKind.PAREN_R @@ -468,7 +468,7 @@ def parse_nullability_assertion(self) -> NullabilityAssertionNode | None: return nullability_assertion - def parse_arguments(self, is_const: bool) -> list[ArgumentNode]: + def parse_arguments(self, is_const: bool) -> tuple[ArgumentNode, ...]: """Arguments[Const]: (Argument[?Const]+)""" item = self.parse_const_argument if is_const else self.parse_argument return self.optional_many( @@ -533,6 +533,7 @@ def parse_fragment_definition(self) -> FragmentDefinitionNode: ) return FragmentDefinitionNode( name=self.parse_fragment_name(), + variable_definitions=(), type_condition=self.parse_type_condition(), directives=self.parse_directives(False), selection_set=self.parse_selection_set(), @@ -646,16 +647,16 @@ def parse_const_value_literal(self) -> ConstValueNode: # Implement the parsing rules in the Directives section. - def parse_directives(self, is_const: bool) -> list[DirectiveNode]: + def parse_directives(self, is_const: bool) -> tuple[DirectiveNode, ...]: """Directives[Const]: Directive[?Const]+""" directives: list[DirectiveNode] = [] append = directives.append while self.peek(TokenKind.AT): append(self.parse_directive(is_const)) - return directives + return tuple(directives) - def parse_const_directives(self) -> list[ConstDirectiveNode]: - return cast("List[ConstDirectiveNode]", self.parse_directives(True)) + def parse_const_directives(self) -> tuple[ConstDirectiveNode, ...]: + return cast("tuple[ConstDirectiveNode, ...]", self.parse_directives(True)) def parse_directive(self, is_const: bool) -> DirectiveNode: """Directive[Const]: @ Name Arguments[?Const]?""" @@ -778,15 +779,15 @@ def parse_object_type_definition(self) -> ObjectTypeDefinitionNode: loc=self.loc(start), ) - def parse_implements_interfaces(self) -> list[NamedTypeNode]: + def parse_implements_interfaces(self) -> tuple[NamedTypeNode, ...]: """ImplementsInterfaces""" return ( self.delimited_many(TokenKind.AMP, self.parse_named_type) if self.expect_optional_keyword("implements") - else [] + else () ) - def parse_fields_definition(self) -> list[FieldDefinitionNode]: + def parse_fields_definition(self) -> tuple[FieldDefinitionNode, ...]: """FieldsDefinition: {FieldDefinition+}""" return self.optional_many( TokenKind.BRACE_L, self.parse_field_definition, TokenKind.BRACE_R @@ -810,7 +811,7 @@ def parse_field_definition(self) -> FieldDefinitionNode: loc=self.loc(start), ) - def parse_argument_defs(self) -> list[InputValueDefinitionNode]: + def parse_argument_defs(self) -> tuple[InputValueDefinitionNode, ...]: """ArgumentsDefinition: (InputValueDefinition+)""" return self.optional_many( TokenKind.PAREN_L, self.parse_input_value_def, TokenKind.PAREN_R @@ -872,12 +873,12 @@ def parse_union_type_definition(self) -> UnionTypeDefinitionNode: loc=self.loc(start), ) - def parse_union_member_types(self) -> list[NamedTypeNode]: + def parse_union_member_types(self) -> tuple[NamedTypeNode, ...]: """UnionMemberTypes""" return ( self.delimited_many(TokenKind.PIPE, self.parse_named_type) if self.expect_optional_token(TokenKind.EQUALS) - else [] + else () ) def parse_enum_type_definition(self) -> EnumTypeDefinitionNode: @@ -896,7 +897,7 @@ def parse_enum_type_definition(self) -> EnumTypeDefinitionNode: loc=self.loc(start), ) - def parse_enum_values_definition(self) -> list[EnumValueDefinitionNode]: + def parse_enum_values_definition(self) -> tuple[EnumValueDefinitionNode, ...]: """EnumValuesDefinition: {EnumValueDefinition+}""" return self.optional_many( TokenKind.BRACE_L, self.parse_enum_value_definition, TokenKind.BRACE_R @@ -942,7 +943,7 @@ def parse_input_object_type_definition(self) -> InputObjectTypeDefinitionNode: loc=self.loc(start), ) - def parse_input_fields_definition(self) -> list[InputValueDefinitionNode]: + def parse_input_fields_definition(self) -> tuple[InputValueDefinitionNode, ...]: """InputFieldsDefinition: {InputValueDefinition+}""" return self.optional_many( TokenKind.BRACE_L, self.parse_input_value_def, TokenKind.BRACE_R @@ -1076,7 +1077,7 @@ def parse_directive_definition(self) -> DirectiveDefinitionNode: loc=self.loc(start), ) - def parse_directive_locations(self) -> list[NameNode]: + def parse_directive_locations(self) -> tuple[NameNode, ...]: """DirectiveLocations""" return self.delimited_many(TokenKind.PIPE, self.parse_directive_location) @@ -1173,11 +1174,11 @@ def unexpected(self, at_token: Token | None = None) -> GraphQLError: def any( self, open_kind: TokenKind, parse_fn: Callable[[], T], close_kind: TokenKind - ) -> list[T]: + ) -> tuple[T, ...]: """Fetch any matching nodes, possibly none. - Returns a possibly empty list of parse nodes, determined by the ``parse_fn``. - This list begins with a lex token of ``open_kind`` and ends with a lex token of + Returns a possibly empty tuple of parse nodes, determined by the ``parse_fn``. + This tuple begins with a lex token of ``open_kind`` and ends with a lex token of ``close_kind``. Advances the parser to the next lex token after the closing token. """ @@ -1187,16 +1188,16 @@ def any( expect_optional_token = partial(self.expect_optional_token, close_kind) while not expect_optional_token(): append(parse_fn()) - return nodes + return tuple(nodes) def optional_many( self, open_kind: TokenKind, parse_fn: Callable[[], T], close_kind: TokenKind - ) -> list[T]: + ) -> tuple[T, ...]: """Fetch matching nodes, maybe none. - Returns a list of parse nodes, determined by the ``parse_fn``. It can be empty + Returns a tuple of parse nodes, determined by the ``parse_fn``. It can be empty only if the open token is missing, otherwise it will always return a non-empty - list that begins with a lex token of ``open_kind`` and ends with a lex token of + tuple that begins with a lex token of ``open_kind`` and ends with a lex token of ``close_kind``. Advances the parser to the next lex token after the closing token. """ @@ -1206,16 +1207,16 @@ def optional_many( expect_optional_token = partial(self.expect_optional_token, close_kind) while not expect_optional_token(): append(parse_fn()) - return nodes - return [] + return tuple(nodes) + return () def many( self, open_kind: TokenKind, parse_fn: Callable[[], T], close_kind: TokenKind - ) -> list[T]: + ) -> tuple[T, ...]: """Fetch matching nodes, at least one. - Returns a non-empty list of parse nodes, determined by the ``parse_fn``. This - list begins with a lex token of ``open_kind`` and ends with a lex token of + Returns a non-empty tuple of parse nodes, determined by the ``parse_fn``. This + tuple begins with a lex token of ``open_kind`` and ends with a lex token of ``close_kind``. Advances the parser to the next lex token after the closing token. """ @@ -1225,17 +1226,17 @@ def many( expect_optional_token = partial(self.expect_optional_token, close_kind) while not expect_optional_token(): append(parse_fn()) - return nodes + return tuple(nodes) def delimited_many( self, delimiter_kind: TokenKind, parse_fn: Callable[[], T] - ) -> list[T]: + ) -> tuple[T, ...]: """Fetch many delimited nodes. - Returns a non-empty list of parse nodes, determined by the ``parse_fn``. This - list may begin with a lex token of ``delimiter_kind`` followed by items + Returns a non-empty tuple of parse nodes, determined by the ``parse_fn``. This + tuple may begin with a lex token of ``delimiter_kind`` followed by items separated by lex tokens of ``delimiter_kind``. Advances the parser to the next - lex token after the last item in the list. + lex token after the last item in the tuple. """ expect_optional_token = partial(self.expect_optional_token, delimiter_kind) expect_optional_token() @@ -1245,7 +1246,7 @@ def delimited_many( append(parse_fn()) if not expect_optional_token(): break - return nodes + return tuple(nodes) def advance_lexer(self) -> None: """Advance the lexer.""" diff --git a/src/graphql/utilities/ast_to_dict.py b/src/graphql/utilities/ast_to_dict.py index 10f13c15..c276868d 100644 --- a/src/graphql/utilities/ast_to_dict.py +++ b/src/graphql/utilities/ast_to_dict.py @@ -45,14 +45,18 @@ def ast_to_dict( elif node in cache: return cache[node] cache[node] = res = {} + # Note: We don't use msgspec.structs.asdict() because loc needs special + # handling (converted to {start, end} dict rather than full Location object) + # Filter out 'loc' - it's handled separately for the locations option + fields = [f for f in node.keys if f != "loc"] res.update( { key: ast_to_dict(getattr(node, key), locations, cache) - for key in ("kind", *node.keys[1:]) + for key in ("kind", *fields) } ) if locations: - loc = node.loc + loc = getattr(node, "loc", None) if loc: res["loc"] = {"start": loc.start, "end": loc.end} return res diff --git a/src/graphql/utilities/concat_ast.py b/src/graphql/utilities/concat_ast.py index 806292f9..6a2398c3 100644 --- a/src/graphql/utilities/concat_ast.py +++ b/src/graphql/utilities/concat_ast.py @@ -17,6 +17,5 @@ def concat_ast(asts: Collection[DocumentNode]) -> DocumentNode: the ASTs together into batched AST, useful for validating many GraphQL source files which together represent one conceptual application. """ - return DocumentNode( - definitions=list(chain.from_iterable(document.definitions for document in asts)) - ) + all_definitions = chain.from_iterable(doc.definitions for doc in asts) + return DocumentNode(definitions=tuple(all_definitions)) diff --git a/src/graphql/utilities/separate_operations.py b/src/graphql/utilities/separate_operations.py index 53867662..45589404 100644 --- a/src/graphql/utilities/separate_operations.py +++ b/src/graphql/utilities/separate_operations.py @@ -60,7 +60,7 @@ def separate_operations(document_ast: DocumentNode) -> dict[str, DocumentNode]: # The list of definition nodes to be included for this operation, sorted # to retain the same order as the original document. separated_document_asts[operation_name] = DocumentNode( - definitions=[ + definitions=tuple( node for node in document_ast.definitions if node is operation @@ -68,7 +68,7 @@ def separate_operations(document_ast: DocumentNode) -> dict[str, DocumentNode]: isinstance(node, FragmentDefinitionNode) and node.name.value in dependencies ) - ] + ) ) return separated_document_asts diff --git a/src/graphql/utilities/sort_value_node.py b/src/graphql/utilities/sort_value_node.py index bf20cf37..970978ee 100644 --- a/src/graphql/utilities/sort_value_node.py +++ b/src/graphql/utilities/sort_value_node.py @@ -2,8 +2,6 @@ from __future__ import annotations -from copy import copy - from ..language import ListValueNode, ObjectFieldNode, ObjectValueNode, ValueNode from ..pyutils import natural_comparison_key @@ -18,18 +16,23 @@ def sort_value_node(value_node: ValueNode) -> ValueNode: For internal use only. """ if isinstance(value_node, ObjectValueNode): - value_node = copy(value_node) - value_node.fields = sort_fields(value_node.fields) + # Create new node with updated fields (immutable-friendly copy-on-write) + values = {k: getattr(value_node, k) for k in value_node.keys} + values["fields"] = sort_fields(value_node.fields) + value_node = value_node.__class__(**values) elif isinstance(value_node, ListValueNode): - value_node = copy(value_node) - value_node.values = tuple(sort_value_node(value) for value in value_node.values) + # Create new node with updated values (immutable-friendly copy-on-write) + values = {k: getattr(value_node, k) for k in value_node.keys} + values["values"] = tuple(sort_value_node(value) for value in value_node.values) + value_node = value_node.__class__(**values) return value_node def sort_field(field: ObjectFieldNode) -> ObjectFieldNode: - field = copy(field) - field.value = sort_value_node(field.value) - return field + # Create new node with updated value (immutable-friendly copy-on-write) + values = {k: getattr(field, k) for k in field.keys} + values["value"] = sort_value_node(field.value) + return field.__class__(**values) def sort_fields(fields: tuple[ObjectFieldNode, ...]) -> tuple[ObjectFieldNode, ...]: diff --git a/tests/language/test_schema_parser.py b/tests/language/test_schema_parser.py index df64381a..3a0e6301 100644 --- a/tests/language/test_schema_parser.py +++ b/tests/language/test_schema_parser.py @@ -78,12 +78,12 @@ def name_node(name: str, loc: Location): def field_node(name: NameNode, type_: TypeNode, loc: Location): - return field_node_with_args(name, type_, [], loc) + return field_node_with_args(name, type_, (), loc) -def field_node_with_args(name: NameNode, type_: TypeNode, args: list, loc: Location): +def field_node_with_args(name: NameNode, type_: TypeNode, args: tuple, loc: Location): return FieldDefinitionNode( - name=name, arguments=args, type=type_, directives=[], loc=loc, description=None + name=name, arguments=args, type=type_, directives=(), loc=loc, description=None ) @@ -93,7 +93,7 @@ def non_null_type(type_: TypeNode, loc: Location): def enum_value_node(name: str, loc: Location): return EnumValueDefinitionNode( - name=name_node(name, loc), directives=[], loc=loc, description=None + name=name_node(name, loc), directives=(), loc=loc, description=None ) @@ -104,7 +104,7 @@ def input_value_node( name=name, type=type_, default_value=default_value, - directives=[], + directives=(), loc=loc, description=None, ) @@ -123,8 +123,8 @@ def list_type_node(type_: TypeNode, loc: Location): def schema_extension_node( - directives: list[DirectiveNode], - operation_types: list[OperationTypeDefinitionNode], + directives: tuple[DirectiveNode, ...], + operation_types: tuple[OperationTypeDefinitionNode, ...], loc: Location, ): return SchemaExtensionNode( @@ -136,7 +136,7 @@ def operation_type_definition(operation: OperationType, type_: TypeNode, loc: Lo return OperationTypeDefinitionNode(operation=operation, type=type_, loc=loc) -def directive_node(name: NameNode, arguments: list[ArgumentNode], loc: Location): +def directive_node(name: NameNode, arguments: tuple[ArgumentNode, ...], loc: Location): return DirectiveNode(name=name, arguments=arguments, loc=loc) @@ -351,14 +351,14 @@ def schema_extension(): assert doc.loc == (0, 75) assert doc.definitions == ( schema_extension_node( - [], - [ + (), + ( operation_type_definition( OperationType.MUTATION, type_node("Mutation", (53, 61)), (43, 61), - ) - ], + ), + ), (13, 75), ), ) @@ -370,8 +370,8 @@ def schema_extension_with_only_directives(): assert doc.loc == (0, 24) assert doc.definitions == ( schema_extension_node( - [directive_node(name_node("directive", (15, 24)), [], (14, 24))], - [], + (directive_node(name_node("directive", (15, 24)), (), (14, 24)),), + (), (0, 24), ), ) @@ -571,14 +571,14 @@ def simple_field_with_arg(): field_node_with_args( name_node("world", (16, 21)), type_node("String", (38, 44)), - [ + ( input_value_node( name_node("flag", (22, 26)), type_node("Boolean", (28, 35)), None, (22, 35), - ) - ], + ), + ), (16, 44), ), ) @@ -602,14 +602,14 @@ def simple_field_with_arg_with_default_value(): field_node_with_args( name_node("world", (16, 21)), type_node("String", (45, 51)), - [ + ( input_value_node( name_node("flag", (22, 26)), type_node("Boolean", (28, 35)), boolean_value_node(True, (38, 42)), (22, 42), - ) - ], + ), + ), (16, 51), ), ) @@ -633,14 +633,14 @@ def simple_field_with_list_arg(): field_node_with_args( name_node("world", (16, 21)), type_node("String", (41, 47)), - [ + ( input_value_node( name_node("things", (22, 28)), list_type_node(type_node("String", (31, 37)), (30, 38)), None, (22, 38), - ) - ], + ), + ), (16, 47), ), ) @@ -664,7 +664,7 @@ def simple_field_with_two_args(): field_node_with_args( name_node("world", (16, 21)), type_node("String", (53, 59)), - [ + ( input_value_node( name_node("argOne", (22, 28)), type_node("Boolean", (30, 37)), @@ -677,7 +677,7 @@ def simple_field_with_two_args(): None, (39, 50), ), - ], + ), (16, 59), ), ) diff --git a/tests/utilities/test_ast_from_value.py b/tests/utilities/test_ast_from_value.py index 947f2b18..5af52924 100644 --- a/tests/utilities/test_ast_from_value.py +++ b/tests/utilities/test_ast_from_value.py @@ -204,13 +204,13 @@ def converts_list_values_to_list_asts(): assert ast_from_value( ["FOO", "BAR"], GraphQLList(GraphQLString) ) == ConstListValueNode( - values=[StringValueNode(value="FOO"), StringValueNode(value="BAR")] + values=(StringValueNode(value="FOO"), StringValueNode(value="BAR")) ) assert ast_from_value( ["HELLO", "GOODBYE"], GraphQLList(my_enum) ) == ConstListValueNode( - values=[EnumValueNode(value="HELLO"), EnumValueNode(value="GOODBYE")] + values=(EnumValueNode(value="HELLO"), EnumValueNode(value="GOODBYE")) ) def list_generator(): @@ -220,11 +220,11 @@ def list_generator(): assert ast_from_value(list_generator(), GraphQLList(GraphQLInt)) == ( ConstListValueNode( - values=[ + values=( IntValueNode(value="1"), IntValueNode(value="2"), IntValueNode(value="3"), - ] + ) ) ) @@ -239,7 +239,7 @@ def skips_invalid_list_items(): ) assert ast == ConstListValueNode( - values=[StringValueNode(value="FOO"), StringValueNode(value="BAR")] + values=(StringValueNode(value="FOO"), StringValueNode(value="BAR")) ) input_obj = GraphQLInputObjectType( @@ -251,21 +251,21 @@ def converts_input_objects(): assert ast_from_value( {"foo": 3, "bar": "HELLO"}, input_obj ) == ConstObjectValueNode( - fields=[ + fields=( ConstObjectFieldNode( name=NameNode(value="foo"), value=FloatValueNode(value="3") ), ConstObjectFieldNode( name=NameNode(value="bar"), value=EnumValueNode(value="HELLO") ), - ] + ) ) def converts_input_objects_with_explicit_nulls(): assert ast_from_value({"foo": None}, input_obj) == ConstObjectValueNode( - fields=[ - ConstObjectFieldNode(name=NameNode(value="foo"), value=NullValueNode()) - ] + fields=( + ConstObjectFieldNode(name=NameNode(value="foo"), value=NullValueNode()), + ) ) def does_not_convert_non_object_values_as_input_objects(): diff --git a/tests/utilities/test_build_ast_schema.py b/tests/utilities/test_build_ast_schema.py index 12e16f8f..63e1614f 100644 --- a/tests/utilities/test_build_ast_schema.py +++ b/tests/utilities/test_build_ast_schema.py @@ -133,7 +133,7 @@ def ignores_non_type_system_definitions(): def match_order_of_default_types_and_directives(): schema = GraphQLSchema() - sdl_schema = build_ast_schema(DocumentNode(definitions=[])) + sdl_schema = build_ast_schema(DocumentNode(definitions=())) assert sdl_schema.directives == schema.directives assert sdl_schema.type_map == schema.type_map diff --git a/tests/utilities/test_type_info.py b/tests/utilities/test_type_info.py index 01f7e464..031a2b0f 100644 --- a/tests/utilities/test_type_info.py +++ b/tests/utilities/test_type_info.py @@ -346,7 +346,7 @@ def enter(*args): arguments=node.arguments, directives=node.directives, selection_set=SelectionSetNode( - selections=[FieldNode(name=NameNode(value="__typename"))] + selections=(FieldNode(name=NameNode(value="__typename")),) ), )