Skip to content

Commit fd7642c

Browse files
committed
Refactor PPL query cancellation to cooperative model and other PR suggestions.
1 parent 024a8ca commit fd7642c

9 files changed

Lines changed: 130 additions & 169 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.logging.log4j.ThreadContext;
13-
import org.opensearch.OpenSearchException;
1413
import org.opensearch.OpenSearchTimeoutException;
1514
import org.opensearch.common.unit.TimeValue;
1615
import org.opensearch.sql.common.setting.Settings;
1716
import org.opensearch.sql.executor.QueryId;
1817
import org.opensearch.sql.executor.QueryManager;
1918
import org.opensearch.sql.executor.execution.AbstractPlan;
19+
import org.opensearch.tasks.CancellableTask;
2020
import org.opensearch.threadpool.Scheduler;
2121
import org.opensearch.threadpool.ThreadPool;
2222
import org.opensearch.transport.client.node.NodeClient;
@@ -34,44 +34,39 @@ public class OpenSearchQueryManager implements QueryManager {
3434
public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker";
3535
public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io";
3636

37-
public interface CancellationCallBack {
38-
void onExecutionThreadAvailable(Thread thread);
39-
void onExecutionComplete();
40-
boolean isCancelled();
41-
}
37+
private static final ThreadLocal<CancellableTask> cancellableTask = new ThreadLocal<>();
4238

43-
public static ThreadLocal<CancellationCallBack> cancellationCallBackThreadLocal = new ThreadLocal<>();
39+
public static void setCancellableTask(CancellableTask task) {
40+
cancellableTask.set(task);
41+
}
4442

45-
public static void setCancellationCallback(CancellationCallBack value) {
46-
cancellationCallBackThreadLocal.set(value);
43+
public static CancellableTask getCancellableTask() {
44+
return cancellableTask.get();
4745
}
4846

49-
public static void clearCancellationCallback() {
50-
cancellationCallBackThreadLocal.remove();
47+
public static void clearCancellableTask() {
48+
cancellableTask.remove();
5149
}
5250

5351
@Override
5452
public QueryId submit(AbstractPlan queryPlan) {
5553
TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT);
56-
CancellationCallBack callBack = cancellationCallBackThreadLocal.get();
57-
cancellationCallBackThreadLocal.remove();
58-
schedule(nodeClient, queryPlan::execute, timeout, callBack);
54+
CancellableTask cancelTask = cancellableTask.get();
55+
cancellableTask.remove();
56+
schedule(nodeClient, queryPlan::execute, timeout, cancelTask);
5957

6058
return queryPlan.getQueryId();
6159
}
6260

63-
private void schedule(NodeClient client, Runnable task, TimeValue timeout, CancellationCallBack callBack) {
61+
private void schedule(
62+
NodeClient client, Runnable task, TimeValue timeout, CancellableTask cancelTask) {
6463
ThreadPool threadPool = client.threadPool();
6564

6665
Runnable wrappedTask =
6766
withCurrentContext(
6867
() -> {
6968
final Thread executionThread = Thread.currentThread();
7069

71-
if (callBack != null) {
72-
callBack.onExecutionThreadAvailable(executionThread);
73-
}
74-
7570
Scheduler.ScheduledCancellable timeoutTask =
7671
threadPool.schedule(
7772
() -> {
@@ -83,6 +78,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout, Cance
8378
timeout,
8479
ThreadPool.Names.GENERIC);
8580

81+
setCancellableTask(cancelTask);
82+
8683
try {
8784
task.run();
8885
timeoutTask.cancel();
@@ -93,21 +90,14 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout, Cance
9390

9491
// Special-case handling of timeout-related interruptions
9592
if (Thread.interrupted() || e.getCause() instanceof InterruptedException) {
96-
if (callBack != null && callBack.isCancelled()) {
97-
LOG.info("Query was cancelled");
98-
throw new OpenSearchException("Query was cancelled.");
99-
}
10093
LOG.error("Query was interrupted due to timeout after {}", timeout);
10194
throw new OpenSearchTimeoutException(
10295
"Query execution timed out after " + timeout);
10396
}
10497

10598
throw e;
106-
}
107-
finally {
108-
if (callBack != null) {
109-
callBack.onExecutionComplete();
110-
}
99+
} finally {
100+
clearCancellableTask();
111101
}
112102
});
113103

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@
1111
import lombok.EqualsAndHashCode;
1212
import lombok.ToString;
1313
import org.apache.calcite.linq4j.Enumerator;
14+
import org.opensearch.core.tasks.TaskCancelledException;
1415
import org.opensearch.sql.data.model.ExprValue;
1516
import org.opensearch.sql.data.model.ExprValueUtils;
1617
import org.opensearch.sql.exception.NonFallbackCalciteException;
1718
import org.opensearch.sql.expression.HighlightExpression;
1819
import org.opensearch.sql.monitor.ResourceMonitor;
1920
import org.opensearch.sql.opensearch.client.OpenSearchClient;
21+
import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager;
2022
import org.opensearch.sql.opensearch.request.OpenSearchRequest;
23+
import org.opensearch.tasks.CancellableTask;
2124

2225
/**
2326
* Supports a simple iteration over a collection for OpenSearch index
@@ -55,6 +58,8 @@ public class OpenSearchIndexEnumerator implements Enumerator<Object> {
5558

5659
private ExprValue current = null;
5760

61+
private CancellableTask cancellableTask;
62+
5863
public OpenSearchIndexEnumerator(
5964
OpenSearchClient client,
6065
List<String> fields,
@@ -80,6 +85,7 @@ public OpenSearchIndexEnumerator(
8085
this.client = client;
8186
this.bgScanner = new BackgroundSearchScanner(client, maxResultWindow, queryBucketSize);
8287
this.bgScanner.startScanning(request);
88+
this.cancellableTask = OpenSearchQueryManager.getCancellableTask();
8389
}
8490

8591
private Iterator<ExprValue> fetchNextBatch() {
@@ -112,6 +118,10 @@ public boolean moveNext() {
112118
return false;
113119
}
114120

121+
if (cancellableTask != null && cancellableTask.isCancelled()) {
122+
throw new TaskCancelledException("The task is cancelled.");
123+
}
124+
115125
boolean shouldCheck = (queryCount % NUMBER_OF_NEXT_CALL_TO_CHECK == 0);
116126
if (shouldCheck) {
117127
org.opensearch.sql.monitor.ResourceStatus status = this.monitor.getStatus();

plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques
116116
// set queryId
117117
String queryId = jsonContent.optString("queryId", null);
118118
if (queryId != null) {
119-
pplRequest.queryId(queryId);
119+
pplRequest.queryId(queryId);
120120
}
121121
return pplRequest;
122122
} catch (JSONException e) {
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.plugin.transport;
7+
8+
import java.util.Map;
9+
import org.opensearch.core.tasks.TaskId;
10+
import org.opensearch.tasks.CancellableTask;
11+
12+
public class PPLQueryTask extends CancellableTask {
13+
14+
public PPLQueryTask(
15+
long id,
16+
String type,
17+
String action,
18+
String description,
19+
TaskId parentTaskId,
20+
Map<String, String> headers) {
21+
super(id, type, action, description, parentTaskId, headers);
22+
}
23+
24+
@Override
25+
public boolean shouldCancelChildrenOnCancellation() {
26+
return true;
27+
}
28+
}

plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java

Lines changed: 0 additions & 44 deletions
This file was deleted.

plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -110,24 +110,8 @@ protected void doExecute(
110110
return;
111111
}
112112

113-
if (task instanceof SQLQueryTask sqlQueryTask) {
114-
115-
OpenSearchQueryManager.setCancellationCallback(new OpenSearchQueryManager.CancellationCallBack() {
116-
@Override
117-
public void onExecutionThreadAvailable(Thread thread) {
118-
sqlQueryTask.setExecutionThread(thread);
119-
}
120-
121-
@Override
122-
public void onExecutionComplete() {
123-
sqlQueryTask.clearExecutionThread();
124-
}
125-
126-
@Override
127-
public boolean isCancelled() {
128-
return sqlQueryTask.isCancelled();
129-
}
130-
});
113+
if (task instanceof PPLQueryTask pplQueryTask) {
114+
OpenSearchQueryManager.setCancellableTask(pplQueryTask);
131115
}
132116
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment();
133117
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment();

plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.opensearch.sql.ppl.domain.PPLQueryRequest;
2727
import org.opensearch.sql.protocol.response.format.Format;
2828
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
29-
import org.opensearch.tasks.Task;
3029

3130
@RequiredArgsConstructor
3231
public class TransportPPLQueryRequest extends ActionRequest {
@@ -159,20 +158,15 @@ public ActionRequestValidationException validate() {
159158
}
160159

161160
@Override
162-
public SQLQueryTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
163-
return new SQLQueryTask(id, type, action, getDescription() , parentTaskId, headers);
161+
public PPLQueryTask createTask(
162+
long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
163+
return new PPLQueryTask(id, type, action, getDescription(), parentTaskId, headers);
164164
}
165165

166166
@Override
167-
public String getDescription()
168-
{
169-
String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: ";
170-
171-
if (pplQuery != null && pplQuery.length() > 512) {
172-
return prefix + pplQuery.substring(0,512) + "...";
173-
}
174-
175-
return prefix + pplQuery;
167+
public String getDescription() {
168+
String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: ";
169+
return prefix + pplQuery;
176170
}
177171

178172
/** Convert to PPLQueryRequest. */
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.plugin.transport;
7+
8+
import static org.junit.Assert.*;
9+
10+
import java.util.Map;
11+
import org.junit.Test;
12+
import org.opensearch.core.tasks.TaskId;
13+
14+
public class PPLQueryTaskTest {
15+
16+
@Test
17+
public void testShouldCancelChildrenReturnsTrue() {
18+
PPLQueryTask pplQueryTask =
19+
new PPLQueryTask(
20+
1,
21+
"transport",
22+
"cluster:admin/opensearch/ppl",
23+
"test query",
24+
TaskId.EMPTY_TASK_ID,
25+
Map.of());
26+
assertTrue(pplQueryTask.shouldCancelChildrenOnCancellation());
27+
}
28+
29+
@Test
30+
public void testCreateTaskReturnsPPLQueryTask() {
31+
TransportPPLQueryRequest transportPPLQueryRequest =
32+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
33+
PPLQueryTask task =
34+
transportPPLQueryRequest.createTask(
35+
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
36+
assertNotNull(task);
37+
}
38+
39+
@Test
40+
public void testWithQueryId() {
41+
TransportPPLQueryRequest transportPPLQueryRequest =
42+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
43+
transportPPLQueryRequest.queryId("test-123");
44+
assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription());
45+
}
46+
47+
@Test
48+
public void testWithoutQueryId() {
49+
TransportPPLQueryRequest transportPPLQueryRequest =
50+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
51+
assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription());
52+
}
53+
54+
@Test
55+
public void testCooperativeModel() {
56+
TransportPPLQueryRequest transportPPLQueryRequest =
57+
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
58+
PPLQueryTask task =
59+
transportPPLQueryRequest.createTask(
60+
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
61+
assertFalse(task.isCancelled());
62+
task.cancel("Test");
63+
assertTrue(task.isCancelled());
64+
}
65+
}

0 commit comments

Comments
 (0)