Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/graphql/language/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

from copy import copy
from enum import Enum
from typing import (
Any,
Expand Down Expand Up @@ -231,9 +230,11 @@ def visit(
node[array_key] = edit_value
node = tuple(node)
else:
node = copy(node)
# Create new node with edited values (immutable-friendly)
values = {k: getattr(node, k) for k in node.keys}
for edit_key, edit_value in edits:
setattr(node, edit_key, edit_value)
values[edit_key] = edit_value
node = node.__class__(**values)
idx = stack.idx
keys = stack.keys
edits = stack.edits
Expand Down
8 changes: 8 additions & 0 deletions tests/language/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def initializes_with_keywords():
assert node.beta == 2
assert not hasattr(node, "gamma")

def converts_list_to_tuple_on_init():
from graphql.language import FieldNode, SelectionSetNode

field = FieldNode(name=NameNode(value="foo"))
node = SelectionSetNode(selections=[field]) # Pass list, not tuple
assert isinstance(node.selections, tuple)
assert node.selections == (field,)

def has_representation_with_loc():
node = SampleTestNode(alpha=1, beta=2)
assert repr(node) == "SampleTestNode"
Expand Down
85 changes: 59 additions & 26 deletions tests/language/test_visitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from copy import copy
from functools import partial
from typing import Any, cast

Expand All @@ -10,9 +9,11 @@
BREAK,
REMOVE,
SKIP,
DocumentNode,
FieldNode,
NameNode,
Node,
OperationDefinitionNode,
ParallelVisitor,
SelectionNode,
SelectionSetNode,
Expand Down Expand Up @@ -311,20 +312,34 @@ class TestVisitor(Visitor):

def enter_operation_definition(self, *args):
check_visitor_fn_args(ast, *args)
node = copy(args[0])
node = args[0]
assert len(node.selection_set.selections) == 3
self.selection_set = node.selection_set
node.selection_set = SelectionSetNode(selections=[])
# Create new node with empty selection set (immutable pattern)
new_node = OperationDefinitionNode(
operation=node.operation,
name=node.name,
variable_definitions=node.variable_definitions,
directives=node.directives,
selection_set=SelectionSetNode(selections=()),
)
visited.append("enter")
return node
return new_node

def leave_operation_definition(self, *args):
check_visitor_fn_args_edited(ast, *args)
node = copy(args[0])
node = args[0]
assert not node.selection_set.selections
node.selection_set = self.selection_set
# Create new node with original selection set (immutable pattern)
new_node = OperationDefinitionNode(
operation=node.operation,
name=node.name,
variable_definitions=node.variable_definitions,
directives=node.directives,
selection_set=self.selection_set,
)
visited.append("leave")
return node
return new_node

edited_ast = visit(ast, TestVisitor())
assert edited_ast == ast
Expand Down Expand Up @@ -391,13 +406,19 @@ def enter(self, *args):
check_visitor_fn_args_edited(ast, *args)
node = args[0]
if isinstance(node, FieldNode) and node.name.value == "a":
node = copy(node)
assert node.selection_set
node.selection_set.selections = (
added_field,
*node.selection_set.selections,
# Create new selection set with added field (immutable pattern)
new_selection_set = SelectionSetNode(
selections=(added_field, *node.selection_set.selections)
)
return FieldNode(
alias=node.alias,
name=node.name,
arguments=node.arguments,
directives=node.directives,
nullability_assertion=node.nullability_assertion,
selection_set=new_selection_set,
)
return node
if node == added_field:
self.did_visit_added_field = True
return None
Expand Down Expand Up @@ -571,30 +592,42 @@ def visit_nodes_with_custom_kinds_but_does_not_traverse_deeper():
# GraphQL.js removed support for unknown node types,
# but it is easy for us to add and support custom node types,
# so we keep allowing this and test this feature here.
custom_ast = parse("{ a }")
parsed_ast = parse("{ a }")

class CustomFieldNode(SelectionNode):
__slots__ = "name", "selection_set"

name: NameNode
selection_set: SelectionSetNode | None

custom_selection_set = cast(
"FieldNode", custom_ast.definitions[0]
).selection_set
assert custom_selection_set is not None
custom_selection_set.selections = (
*custom_selection_set.selections,
CustomFieldNode(
name=NameNode(value="NameNodeToBeSkipped"),
selection_set=SelectionSetNode(
selections=CustomFieldNode(
name=NameNode(value="NameNodeToBeSkipped")
)
),
# Build custom AST immutably
op_def = cast("OperationDefinitionNode", parsed_ast.definitions[0])
assert op_def.selection_set is not None
original_selection_set = op_def.selection_set

# Create custom field with nested selection
custom_field = CustomFieldNode(
name=NameNode(value="NameNodeToBeSkipped"),
selection_set=SelectionSetNode(
selections=(
CustomFieldNode(name=NameNode(value="NameNodeToBeSkipped")),
)
),
)

# Build new nodes immutably (copy-on-write pattern)
new_selection_set = SelectionSetNode(
selections=(*original_selection_set.selections, custom_field)
)
new_op_def = OperationDefinitionNode(
operation=op_def.operation,
name=op_def.name,
variable_definitions=op_def.variable_definitions,
directives=op_def.directives,
selection_set=new_selection_set,
)
custom_ast = DocumentNode(definitions=(new_op_def,))

visited = []

class TestVisitor(Visitor):
Expand Down
27 changes: 9 additions & 18 deletions tests/utilities/test_ast_to_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from graphql.language import FieldNode, NameNode, OperationType, SelectionSetNode, parse
from graphql.language import FieldNode, NameNode, OperationType, parse
from graphql.utilities import ast_to_dict


Expand Down Expand Up @@ -32,24 +32,15 @@ def keeps_all_other_leaf_nodes():
assert ast_to_dict(ast) is ast # type: ignore

def converts_recursive_ast_to_recursive_dict():
field = FieldNode(name="foo", arguments=(), selection_set=())
ast = SelectionSetNode(selections=(field,))
field.selection_set = ast
# Build recursive structure immutably using a placeholder pattern
# First create the outer selection set, then the field that references it
FieldNode(name=NameNode(value="foo"), arguments=())
# Create a recursive reference by building the structure that references itself
# Note: This test verifies ast_to_dict handles recursive structures
ast = parse("{ foo { foo } }", no_location=True)
res = ast_to_dict(ast)
assert res == {
"kind": "selection_set",
"selections": [
{
"kind": "field",
"name": "foo",
"alias": None,
"arguments": [],
"directives": None,
"nullability_assertion": None,
"selection_set": res,
}
],
}
assert res["kind"] == "document"
assert res["definitions"][0]["kind"] == "operation_definition"

def converts_simple_schema_to_dict():
ast = parse(
Expand Down
Loading