Skip to content

Commit 63ec984

Browse files
authored
ArcadeData#3996 fix(cypher): CALL...YIELD preserves variables carried in from WITH (ArcadeData#4009)
1 parent 983da3f commit 63ec984

2 files changed

Lines changed: 158 additions & 12 deletions

File tree

engine/src/main/java/com/arcadedb/query/opencypher/executor/steps/CallStep.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,12 @@ private ResultSet executeChainedCall(final CommandContext context, final int nRe
118118
final boolean hasYield = callClause.hasYield() && !callClause.isYieldAll();
119119
final boolean yieldHasWhere = hasYield && callClause.getYieldWhere() != null;
120120

121-
// Collect all results from all input rows
122-
final List<Iterator<?>> allIters = new ArrayList<>();
121+
// Each entry pairs the originating inputRow with the procedure's result iterator.
122+
// Pairing is required so variables carried in from preceding WITH/MATCH clauses
123+
// can be merged into every yielded result (issue #3996).
124+
// Capacity is bounded by nRecords (the upstream batch limit); capped at 1M to guard
125+
// against Integer.MAX_VALUE being passed as a "fetch all" sentinel.
126+
final List<Map.Entry<Result, Iterator<?>>> allPairs = new ArrayList<>(nRecords > 0 && nRecords < 1000000 ? nRecords : 10);
123127
while (prevResults.hasNext()) {
124128
final Result inputRow = prevResults.next();
125129
final long begin = context.isProfiling() ? System.nanoTime() : 0;
@@ -131,17 +135,24 @@ private ResultSet executeChainedCall(final CommandContext context, final int nRe
131135

132136
if (callResult == null) {
133137
if (callClause.isOptional())
134-
allIters.add(java.util.Collections.singletonList((Object) mergeWithInputRow(inputRow, null)).iterator());
138+
// Use an empty result so YIELD only sees the procedure's outputs (null for every
139+
// field). Pre-merging inputRow here would let YIELD incorrectly read outer-scope
140+
// variables if they share a name with a YIELD field. The lazy iterator merges
141+
// inputRow later, after YIELD filtering.
142+
allPairs.add(Map.entry(inputRow,
143+
java.util.Collections.singletonList((Object) new ResultInternal()).iterator()));
135144
continue;
136145
}
137146

147+
final Iterator<?> iter;
138148
if (callResult instanceof Iterator) {
139-
allIters.add((Iterator<?>) callResult);
149+
iter = (Iterator<?>) callResult;
140150
} else if (callResult instanceof Collection) {
141-
allIters.add(((Collection<?>) callResult).iterator());
151+
iter = ((Collection<?>) callResult).iterator();
142152
} else {
143-
allIters.add(java.util.Collections.singletonList(callResult).iterator());
153+
iter = java.util.Collections.singletonList(callResult).iterator();
144154
}
155+
allPairs.add(Map.entry(inputRow, iter));
145156
} finally {
146157
if (context.isProfiling())
147158
cost += (System.nanoTime() - begin);
@@ -176,9 +187,12 @@ public void close() {
176187
};
177188
}
178189

179-
// Standard path: lazily iterate through all result iterators
180-
final Iterator<Iterator<?>> iterOfIters = allIters.iterator();
190+
// Standard path: lazily iterate through all (inputRow, resultIterator) pairs.
191+
// Each yielded result is merged with its originating inputRow so that variables
192+
// from a preceding WITH/MATCH clause remain visible after CALL ... YIELD.
193+
final Iterator<Map.Entry<Result, Iterator<?>>> pairIter = allPairs.iterator();
181194
final Iterator<Result> lazyIter = new Iterator<>() {
195+
private Result currentInputRow = null;
182196
private Iterator<?> currentIter = null;
183197
private Result next = null;
184198

@@ -190,15 +204,17 @@ public boolean hasNext() {
190204
if (hasYield) {
191205
final ResultInternal filtered = applyYieldToSingleResult(converted);
192206
if (filtered != null) {
193-
next = filtered;
207+
next = mergeWithInputRow(currentInputRow, filtered);
194208
return true;
195209
}
196210
} else {
197-
next = converted;
211+
next = mergeWithInputRow(currentInputRow, converted);
198212
return true;
199213
}
200-
} else if (iterOfIters.hasNext()) {
201-
currentIter = iterOfIters.next();
214+
} else if (pairIter.hasNext()) {
215+
final Map.Entry<Result, Iterator<?>> pair = pairIter.next();
216+
currentInputRow = pair.getKey();
217+
currentIter = pair.getValue();
202218
} else {
203219
return false;
204220
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright © 2021-present Arcade Data Ltd (info@arcadedata.com)
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* SPDX-FileCopyrightText: 2021-present Arcade Data Ltd (info@arcadedata.com)
17+
* SPDX-License-Identifier: Apache-2.0
18+
*/
19+
package com.arcadedb.query.opencypher;
20+
21+
import com.arcadedb.database.Database;
22+
import com.arcadedb.database.DatabaseFactory;
23+
import com.arcadedb.query.sql.executor.Result;
24+
import com.arcadedb.query.sql.executor.ResultSet;
25+
import org.junit.jupiter.api.AfterEach;
26+
import org.junit.jupiter.api.BeforeEach;
27+
import org.junit.jupiter.api.Test;
28+
29+
import java.util.ArrayList;
30+
import java.util.List;
31+
32+
import static org.assertj.core.api.Assertions.assertThat;
33+
34+
/**
35+
* Regression tests for GitHub issue #3996:
36+
* CALL ... YIELD may null out variables carried in from WITH.
37+
*/
38+
class CypherCallYieldWithVariablesTest {
39+
private Database database;
40+
41+
@BeforeEach
42+
void setUp() {
43+
database = new DatabaseFactory("./target/databases/testopencypher-call-yield-with").create();
44+
final var personType = database.getSchema().createVertexType("Person");
45+
personType.createProperty("name", String.class);
46+
database.getSchema().createEdgeType("KNOWS");
47+
database.transaction(() -> database.command("opencypher", "CREATE (:Person {name: 'Alice'})"));
48+
}
49+
50+
@AfterEach
51+
void tearDown() {
52+
if (database != null) {
53+
database.drop();
54+
database = null;
55+
}
56+
}
57+
58+
@Test
59+
void withLiteralPreservedAcrossCallDbLabels() {
60+
final ResultSet rs = database.query("opencypher",
61+
"WITH 1 AS x CALL db.labels() YIELD label RETURN x, label");
62+
63+
final List<Result> rows = new ArrayList<>();
64+
while (rs.hasNext())
65+
rows.add(rs.next());
66+
67+
assertThat(rows).isNotEmpty();
68+
for (final Result row : rows) {
69+
final Number x = row.getProperty("x");
70+
assertThat(x).as("x must not be null across CALL db.labels()").isNotNull();
71+
assertThat(x.longValue()).isEqualTo(1L);
72+
assertThat((Object) row.getProperty("label")).as("label must not be null").isNotNull();
73+
}
74+
}
75+
76+
@Test
77+
void withLiteralPreservedAcrossCallDbRelationshipTypes() {
78+
final ResultSet rs = database.query("opencypher",
79+
"WITH 1 AS x CALL db.relationshipTypes() YIELD relationshipType RETURN x, relationshipType");
80+
81+
final List<Result> rows = new ArrayList<>();
82+
while (rs.hasNext())
83+
rows.add(rs.next());
84+
85+
assertThat(rows).isNotEmpty();
86+
for (final Result row : rows) {
87+
final Number x = row.getProperty("x");
88+
assertThat(x).as("x must not be null across CALL db.relationshipTypes()").isNotNull();
89+
assertThat(x.longValue()).isEqualTo(1L);
90+
assertThat((Object) row.getProperty("relationshipType")).as("relationshipType must not be null").isNotNull();
91+
}
92+
}
93+
94+
@Test
95+
void withLiteralPreservedAcrossCallDbPropertyKeys() {
96+
final ResultSet rs = database.query("opencypher",
97+
"WITH 1 AS x CALL db.propertyKeys() YIELD propertyKey RETURN x, propertyKey");
98+
99+
final List<Result> rows = new ArrayList<>();
100+
while (rs.hasNext())
101+
rows.add(rs.next());
102+
103+
assertThat(rows).isNotEmpty();
104+
for (final Result row : rows) {
105+
final Number x = row.getProperty("x");
106+
assertThat(x).as("x must not be null across CALL db.propertyKeys()").isNotNull();
107+
assertThat(x.longValue()).isEqualTo(1L);
108+
assertThat((Object) row.getProperty("propertyKey")).as("propertyKey must not be null").isNotNull();
109+
}
110+
}
111+
112+
@Test
113+
void aggregatedValuePreservedAcrossCallDbLabels() {
114+
// Aggregated values carried through WITH must also survive CALL ... YIELD
115+
final ResultSet rs = database.query("opencypher",
116+
"MATCH (:Person) WITH count(*) AS c CALL db.labels() YIELD label RETURN c, label");
117+
118+
final List<Result> rows = new ArrayList<>();
119+
while (rs.hasNext())
120+
rows.add(rs.next());
121+
122+
assertThat(rows).isNotEmpty();
123+
for (final Result row : rows) {
124+
final Number c = row.getProperty("c");
125+
assertThat(c).as("count(*) must not be null across CALL db.labels()").isNotNull();
126+
assertThat(c.longValue()).as("count must be 1").isEqualTo(1L);
127+
assertThat((Object) row.getProperty("label")).as("label must not be null").isNotNull();
128+
}
129+
}
130+
}

0 commit comments

Comments
 (0)