Skip to content

Commit d3534cc

Browse files
committed
fix: validate v2 sort and split creates
1 parent 0df2e39 commit d3534cc

2 files changed

Lines changed: 99 additions & 4 deletions

File tree

query/v2/query.go

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,10 +385,24 @@ func Desc(expression any) *cypher.SortItem {
385385
return Order(expression, SortDescending)
386386
}
387387

388+
func validateSortDirection(direction SortDirection) error {
389+
switch direction {
390+
case SortAscending, SortDescending:
391+
return nil
392+
default:
393+
return fmt.Errorf("unsupported sort direction: %d", direction)
394+
}
395+
}
396+
388397
func Order(expression any, direction SortDirection) *cypher.SortItem {
398+
expressionValue := expressionOrError(expression)
399+
if err := validateSortDirection(direction); err != nil {
400+
expressionValue = invalidExpression(err)
401+
}
402+
389403
return &cypher.SortItem{
390404
Ascending: direction != SortDescending,
391-
Expression: expressionOrError(expression),
405+
Expression: expressionValue,
392406
}
393407
}
394408

@@ -1097,25 +1111,57 @@ func isCreateNodeValue(value any, identifiers runtimeIdentifiers) bool {
10971111
return false
10981112
}
10991113

1114+
func isCreateRelationshipValue(value any) bool {
1115+
_, typeOK := value.(*cypher.RelationshipPattern)
1116+
return typeOK
1117+
}
1118+
11001119
func nextCreateValueIsNode(creates []any, idx int, identifiers runtimeIdentifiers) bool {
11011120
nextIdx := idx + 1
11021121
return nextIdx < len(creates) && isCreateNodeValue(creates[nextIdx], identifiers)
11031122
}
11041123

1124+
func newCreatePatternPart(createClause *cypher.Create) *cypher.PatternPart {
1125+
pattern := &cypher.PatternPart{}
1126+
createClause.Pattern = append(createClause.Pattern, pattern)
1127+
return pattern
1128+
}
1129+
1130+
func createPatternHasElements(pattern *cypher.PatternPart) bool {
1131+
return pattern != nil && len(pattern.PatternElements) > 0
1132+
}
1133+
1134+
func shouldStartNewCreatePattern(pattern *cypher.PatternPart, nextCreate any, patternClosed bool, identifiers runtimeIdentifiers) bool {
1135+
if !createPatternHasElements(pattern) {
1136+
return false
1137+
}
1138+
1139+
if isCreateNodeValue(nextCreate, identifiers) && patternEndsWithNodePattern(pattern) {
1140+
return true
1141+
}
1142+
1143+
return patternClosed && isCreateRelationshipValue(nextCreate)
1144+
}
1145+
11051146
func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error {
11061147
if len(creates) == 0 {
11071148
return nil
11081149
}
11091150

11101151
var (
1111-
pattern = &cypher.PatternPart{}
11121152
createClause = &cypher.Create{
1113-
Unique: false,
1114-
Pattern: []*cypher.PatternPart{pattern},
1153+
Unique: false,
11151154
}
1155+
pattern = newCreatePatternPart(createClause)
1156+
patternClosed bool
11161157
)
11171158

11181159
for idx, nextCreate := range creates {
1160+
if shouldStartNewCreatePattern(pattern, nextCreate, patternClosed, identifiers) {
1161+
pattern = newCreatePatternPart(createClause)
1162+
patternClosed = false
1163+
}
1164+
11191165
switch typedNextCreate := nextCreate.(type) {
11201166
case QualifiedExpression:
11211167
switch typedExpression := typedNextCreate.qualifier().(type) {
@@ -1129,6 +1175,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId
11291175
pattern.AddPatternElements(&cypher.NodePattern{
11301176
Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol),
11311177
})
1178+
patternClosed = false
11321179

11331180
default:
11341181
return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol)
@@ -1144,6 +1191,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId
11441191
}
11451192

11461193
pattern.AddPatternElements(cypher.Copy(typedNextCreate))
1194+
patternClosed = false
11471195

11481196
case *cypher.RelationshipPattern:
11491197
if err := validateRelationshipPattern(typedNextCreate); err != nil {
@@ -1162,6 +1210,9 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId
11621210
pattern.AddPatternElements(&cypher.NodePattern{
11631211
Variable: identifiers.End(),
11641212
})
1213+
patternClosed = true
1214+
} else {
1215+
patternClosed = false
11651216
}
11661217

11671218
default:

query/v2/query_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ func renderPrepared(t *testing.T, preparedQuery *v2.PreparedQuery) string {
1919
return cypherQueryStr
2020
}
2121

22+
func firstCreateClause(t *testing.T, preparedQuery *v2.PreparedQuery) *cypher.Create {
23+
t.Helper()
24+
25+
updatingClauses := preparedQuery.Query.SingleQuery.SinglePartQuery.UpdatingClauses
26+
require.NotEmpty(t, updatingClauses)
27+
28+
updatingClause, typeOK := updatingClauses[0].(*cypher.UpdatingClause)
29+
require.True(t, typeOK)
30+
31+
createClause, typeOK := updatingClause.Clause.(*cypher.Create)
32+
require.True(t, typeOK)
33+
34+
return createClause
35+
}
36+
2237
func TestQuery(t *testing.T) {
2338
preparedQuery, err := v2.New().Where(
2439
v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))),
@@ -102,6 +117,28 @@ func TestCreateRelationshipWithExplicitEndpoints(t *testing.T) {
102117
}, preparedQuery.Parameters)
103118
}
104119

120+
func TestCreateSplitsDisjointNodePatterns(t *testing.T) {
121+
preparedQuery, err := v2.New().Create(
122+
v2.NodePattern(graph.Kinds{graph.StringKind("A")}, nil),
123+
v2.NodePattern(graph.Kinds{graph.StringKind("B")}, nil),
124+
).Build()
125+
require.NoError(t, err)
126+
127+
require.Equal(t, "create (n:A), (n:B)", renderPrepared(t, preparedQuery))
128+
require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2)
129+
}
130+
131+
func TestCreateSplitsBackToBackRelationshipPatterns(t *testing.T) {
132+
preparedQuery, err := v2.New().Create(
133+
v2.RelationshipPattern(graph.StringKind("A"), nil, graph.DirectionOutbound),
134+
v2.RelationshipPattern(graph.StringKind("B"), nil, graph.DirectionOutbound),
135+
).Build()
136+
require.NoError(t, err)
137+
138+
require.Equal(t, "create (s)-[r:A]->(e), (s)-[r:B]->(e)", renderPrepared(t, preparedQuery))
139+
require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2)
140+
}
141+
105142
func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) {
106143
preparedQuery, err := v2.New().Create(
107144
v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})),
@@ -532,6 +569,13 @@ func TestProjectionAndOrderHelpers(t *testing.T) {
532569
require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery))
533570
}
534571

572+
func TestInvalidSortDirectionReturnsError(t *testing.T) {
573+
_, err := v2.New().Return(v2.Node()).OrderBy(
574+
v2.Order(v2.Node().Property("name"), v2.SortDirection(99)),
575+
).Build()
576+
require.ErrorContains(t, err, "unsupported sort direction: 99")
577+
}
578+
535579
func TestPaginationZeroValuesAndNegativeValidation(t *testing.T) {
536580
preparedQuery, err := v2.New().Return(v2.Node()).Skip(0).Limit(0).Build()
537581
require.NoError(t, err)

0 commit comments

Comments
 (0)