Skip to content

Commit 5e192b5

Browse files
committed
fix (query/v2): clean up some poorly supported neo4j constructs
1 parent 4ac05a8 commit 5e192b5

4 files changed

Lines changed: 145 additions & 11 deletions

File tree

query/neo4j/neo4j.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ func (s *QueryBuilder) rewriteParameters() error {
5353
return nil
5454
}
5555

56+
func hasPreparedMatchPattern(readingClause *cypher.ReadingClause) bool {
57+
if readingClause == nil || readingClause.Match == nil {
58+
return false
59+
}
60+
61+
return len(readingClause.Match.Pattern) > 0
62+
}
63+
5664
func (s *QueryBuilder) Apply(criteria graph.Criteria) {
5765
switch typedCriteria := criteria.(type) {
5866
case *cypher.Where:
@@ -201,6 +209,10 @@ func (s *QueryBuilder) prepareMatch() error {
201209
return ErrAmbiguousQueryVariables
202210
}
203211

212+
if firstReadingClause := query.GetFirstReadingClause(s.query); hasPreparedMatchPattern(firstReadingClause) {
213+
return nil
214+
}
215+
204216
if singleNodeBound && !creatingSingleNode {
205217
patternPart.AddPatternElements(&cypher.NodePattern{
206218
Variable: cypher.NewVariableWithSymbol(query.NodeSymbol),

query/v2/backend_test.go

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package v2_test
22

33
import (
44
"context"
5+
"strings"
56
"testing"
67

78
"github.com/specterops/dawgs/cypher/models/pgsql/translate"
@@ -49,9 +50,31 @@ func TestBackendParityNeo4jPrepare(t *testing.T) {
4950
v2.Relationship().ID(),
5051
v2.End().ID(),
5152
),
52-
expectedCypher: "match (s)-[r]->(e) where id(s) = $p0 return id(s), id(r), id(e)",
53+
expectedCypher: "match (s)-[r:MemberOf]->(e) where id(s) = $p0 return id(s), id(r), id(e)",
5354
expectedParams: map[string]any{"p0": 1},
5455
},
56+
"shortest path": {
57+
builder: v2.New().WithShortestPaths().Where(
58+
v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
59+
v2.Start().ID().Equals(1),
60+
v2.End().ID().Equals(2),
61+
).Return(
62+
v2.Path(),
63+
),
64+
expectedCypher: "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p",
65+
expectedParams: map[string]any{"p0": 1, "p1": 2},
66+
},
67+
"all shortest paths": {
68+
builder: v2.New().WithAllShortestPaths().Where(
69+
v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
70+
v2.Start().ID().Equals(1),
71+
v2.End().ID().Equals(2),
72+
).Return(
73+
v2.Path(),
74+
),
75+
expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p",
76+
expectedParams: map[string]any{"p0": 1, "p1": 2},
77+
},
5578
"create node": {
5679
builder: v2.New().Create(
5780
v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})),
@@ -79,6 +102,15 @@ func TestBackendParityNeo4jPrepare(t *testing.T) {
79102
expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r",
80103
expectedParams: map[string]any{"p0": 1},
81104
},
105+
"delete node": {
106+
builder: v2.New().Where(
107+
v2.Node().ID().Equals(1),
108+
).Delete(
109+
v2.Node(),
110+
),
111+
expectedCypher: "match (n) where id(n) = $p0 detach delete n",
112+
expectedParams: map[string]any{"p0": 1},
113+
},
82114
}
83115

84116
for name, testCase := range cases {
@@ -148,6 +180,15 @@ func TestBackendParityPGTranslate(t *testing.T) {
148180
expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;",
149181
expectedParams: map[string]any{"p0": 1, "pi0": 1},
150182
},
183+
"delete node": {
184+
builder: v2.New().Where(
185+
v2.Node().ID().Equals(1),
186+
).Delete(
187+
v2.Node(),
188+
),
189+
expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (delete from node n1 using s0 where (s0.n0).id = n1.id) select 1;",
190+
expectedParams: map[string]any{"p0": 1, "pi0": 1},
191+
},
151192
}
152193

153194
for name, testCase := range cases {
@@ -166,6 +207,63 @@ func TestBackendParityPGTranslate(t *testing.T) {
166207
}
167208
}
168209

210+
func TestBackendParityPGTranslateShortestPaths(t *testing.T) {
211+
edgeKind := graph.StringKind("MemberOf")
212+
mapper := testKindMapper(edgeKind)
213+
214+
cases := map[string]struct {
215+
builder v2.QueryBuilder
216+
expectedHarness string
217+
}{
218+
"shortest path": {
219+
builder: v2.New().WithShortestPaths().Where(
220+
v2.Relationship().Kind().Is(edgeKind),
221+
v2.Start().ID().Equals(1),
222+
v2.End().ID().Equals(2),
223+
).Return(
224+
v2.Path(),
225+
),
226+
expectedHarness: "unidirectional_sp_harness",
227+
},
228+
"all shortest paths": {
229+
builder: v2.New().WithAllShortestPaths().Where(
230+
v2.Relationship().Kind().Is(edgeKind),
231+
v2.Start().ID().Equals(1),
232+
v2.End().ID().Equals(2),
233+
).Return(
234+
v2.Path(),
235+
),
236+
expectedHarness: "bidirectional_asp_harness",
237+
},
238+
}
239+
240+
for name, testCase := range cases {
241+
t.Run(name, func(t *testing.T) {
242+
preparedQuery, err := testCase.builder.Build()
243+
require.NoError(t, err)
244+
245+
translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters)
246+
require.NoError(t, err)
247+
248+
sql, err := translate.Translated(translation)
249+
require.NoError(t, err)
250+
require.Contains(t, sql, testCase.expectedHarness)
251+
require.Contains(t, sql, "edges_to_path")
252+
require.Equal(t, 1, translation.Parameters["p0"])
253+
require.Equal(t, 2, translation.Parameters["p1"])
254+
255+
serializedHarnessQueryHasKindConstraint := false
256+
for _, parameterValue := range translation.Parameters {
257+
if serializedQuery, typeOK := parameterValue.(string); typeOK && strings.Contains(serializedQuery, "array [1]::int2[]") {
258+
serializedHarnessQueryHasKindConstraint = true
259+
break
260+
}
261+
}
262+
require.True(t, serializedHarnessQueryHasKindConstraint, "expected serialized shortest-path harness query to contain edge kind constraint: %#v", translation.Parameters)
263+
})
264+
}
265+
}
266+
169267
func TestBackendParityPGCreateUnsupported(t *testing.T) {
170268
edgeKind := graph.StringKind("MemberOf")
171269
mapper := testKindMapper(edgeKind)

query/v2/query_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,26 +187,28 @@ func TestInvalidRelationshipDirectionReturnsError(t *testing.T) {
187187

188188
func TestShortestPathControls(t *testing.T) {
189189
preparedQuery, err := v2.New().WithShortestPaths().Where(
190+
v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
190191
v2.Start().ID().Equals(1),
191192
v2.End().ID().Equals(2),
192193
).Return(
193194
v2.Path(),
194195
).Build()
195196
require.NoError(t, err)
196-
require.Equal(t, "match p = shortestPath((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
197+
require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
197198
require.Equal(t, map[string]any{
198199
"p0": 1,
199200
"p1": 2,
200201
}, preparedQuery.Parameters)
201202

202203
preparedQuery, err = v2.New().WithAllShortestPaths().Where(
204+
v2.Relationship().Kind().Is(graph.StringKind("MemberOf")),
203205
v2.Start().ID().Equals(1),
204206
v2.End().ID().Equals(2),
205207
).Return(
206208
v2.Path(),
207209
).Build()
208210
require.NoError(t, err)
209-
require.Equal(t, "match p = allShortestPaths((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
211+
require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery))
210212

211213
_, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where(
212214
v2.Start().ID().Equals(1),
@@ -217,6 +219,16 @@ func TestShortestPathControls(t *testing.T) {
217219
require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths")
218220
}
219221

222+
func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) {
223+
_, err := v2.New().Where(
224+
v2.Node().ID().Equals(1),
225+
v2.Relationship().ID().Equals(2),
226+
).Return(
227+
v2.Node(),
228+
).Build()
229+
require.ErrorContains(t, err, "query mixes node and relationship query identifiers")
230+
}
231+
220232
func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) {
221233
_, err := v2.New().Create(
222234
v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth),
@@ -317,6 +329,26 @@ func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) {
317329
require.Equal(t, "match (n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery))
318330
}
319331

332+
func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) {
333+
returnClause := cypher.NewReturn()
334+
projection := returnClause.NewProjection(true)
335+
projection.Items = append(projection.Items, v2.Node().ID())
336+
projection.Order = &cypher.Order{
337+
Items: []*cypher.SortItem{
338+
v2.Desc(v2.Node().Property("name")),
339+
},
340+
}
341+
projection.Skip = cypher.NewSkip(5)
342+
projection.Limit = cypher.NewLimit(10)
343+
344+
preparedQuery, err := v2.New().Return(returnClause).OrderBy(
345+
v2.Asc(v2.Node().Property("created_at")),
346+
).Skip(15).Limit(20).Build()
347+
require.NoError(t, err)
348+
349+
require.Equal(t, "match (n) return distinct id(n) order by n.name desc, n.created_at asc skip 15 limit 20", renderPrepared(t, preparedQuery))
350+
}
351+
320352
func TestRawUpdatingInputsAreValidated(t *testing.T) {
321353
var setClause *cypher.Set
322354
_, err := v2.New().Update(setClause).Build()

query/v2/util.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,6 @@ func (s *identifierSet) CollectFromValue(value any) error {
619619
case QualifiedExpression:
620620
return s.CollectFromExpression(typedValue.qualifier())
621621

622-
case kindContinuation:
623-
s.Add(typedValue.identifier.Symbol)
624-
return nil
625-
626-
case kindsContinuation:
627-
s.Add(typedValue.identifier.Symbol)
628-
return nil
629-
630622
case *cypher.Return:
631623
if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil {
632624
return err

0 commit comments

Comments
 (0)