Skip to content

Commit f9a5518

Browse files
committed
add sendBulk sysdb
1 parent 3a23ab8 commit f9a5518

4 files changed

Lines changed: 431 additions & 86 deletions

File tree

transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import dev.dbos.transact.workflow.ListWorkflowsInput;
2323
import dev.dbos.transact.workflow.NotificationInfo;
2424
import dev.dbos.transact.workflow.Queue;
25+
import dev.dbos.transact.workflow.SendMessage;
2526
import dev.dbos.transact.workflow.QueueOptions;
2627
import dev.dbos.transact.workflow.ScheduleStatus;
2728
import dev.dbos.transact.workflow.StepInfo;
@@ -454,6 +455,19 @@ public Optional<String> checkChildWorkflow(String workflowUuid, int functionId)
454455
return dbRetry(() -> WorkflowDAO.checkChildWorkflow(ctx, workflowUuid, functionId));
455456
}
456457

458+
public void sendBulk(
459+
List<SendMessage> messages,
460+
String workflowId,
461+
int stepId,
462+
String functionName,
463+
boolean sendToForks,
464+
String serialization) {
465+
dbRetry(
466+
() ->
467+
NotificationsDAO.sendBulk(
468+
ctx, messages, workflowId, stepId, functionName, sendToForks, serialization));
469+
}
470+
457471
public void send(
458472
String workflowId,
459473
int stepId,
@@ -462,18 +476,24 @@ public void send(
462476
String topic,
463477
String messageId,
464478
String serialization) {
465-
dbRetry(
466-
() ->
467-
NotificationsDAO.send(
468-
ctx, workflowId, stepId, destinationId, message, topic, messageId, serialization));
479+
sendBulk(
480+
List.of(new SendMessage(destinationId, message, topic, messageId)),
481+
workflowId,
482+
stepId,
483+
"DBOS.send",
484+
false,
485+
serialization);
469486
}
470487

471488
public void sendDirect(
472489
String destinationId, Object message, String topic, String messageId, String serialization) {
473-
dbRetry(
474-
() ->
475-
NotificationsDAO.sendDirect(
476-
ctx, destinationId, message, topic, messageId, serialization));
490+
sendBulk(
491+
List.of(new SendMessage(destinationId, message, topic, messageId)),
492+
null,
493+
-1,
494+
"DBOS.send",
495+
false,
496+
serialization);
477497
}
478498

479499
public Object recv(

transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java

Lines changed: 138 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import dev.dbos.transact.json.DBOSSerializer;
1111
import dev.dbos.transact.json.SerializationUtil;
1212
import dev.dbos.transact.workflow.NotificationInfo;
13+
import dev.dbos.transact.workflow.SendMessage;
1314
import dev.dbos.transact.workflow.internal.StepResult;
1415

1516
import java.sql.Connection;
@@ -18,12 +19,21 @@
1819
import java.sql.SQLException;
1920
import java.time.Duration;
2021
import java.time.Instant;
22+
import java.util.ArrayDeque;
2123
import java.util.ArrayList;
24+
import java.util.Deque;
25+
import java.util.HashMap;
26+
import java.util.LinkedHashMap;
27+
import java.util.LinkedHashSet;
2228
import java.util.List;
29+
import java.util.Map;
2330
import java.util.Objects;
2431
import java.util.Optional;
32+
import java.util.Set;
2533
import java.util.UUID;
2634

35+
import static java.util.stream.Collectors.joining;
36+
2737
import org.jspecify.annotations.NonNull;
2838
import org.jspecify.annotations.Nullable;
2939
import org.slf4j.Logger;
@@ -35,77 +45,165 @@ private NotificationsDAO() {}
3545

3646
private static final Logger logger = LoggerFactory.getLogger(NotificationsDAO.class);
3747

38-
public static void send(
48+
private static Map<String, Set<String>> findForkDescendantsTxn(
49+
Connection conn, String schema, List<String> workflowIds) throws SQLException {
50+
Map<String, List<String>> children = new HashMap<>();
51+
Set<String> seen = new LinkedHashSet<>(workflowIds);
52+
List<String> frontier = new ArrayList<>(new LinkedHashSet<>(workflowIds));
53+
54+
while (!frontier.isEmpty()) {
55+
String placeholders = frontier.stream().map(x -> "?").collect(joining(","));
56+
String sql =
57+
"""
58+
SELECT workflow_uuid, forked_from FROM "%s".workflow_status
59+
WHERE forked_from IN (%s)
60+
""".formatted(schema, placeholders);
61+
try (var stmt = conn.prepareStatement(sql)) {
62+
for (int i = 0; i < frontier.size(); i++) stmt.setString(i + 1, frontier.get(i));
63+
List<String> next = new ArrayList<>();
64+
try (var rs = stmt.executeQuery()) {
65+
while (rs.next()) {
66+
String forkedId = rs.getString("workflow_uuid");
67+
String forkedFrom = rs.getString("forked_from");
68+
children.computeIfAbsent(forkedFrom, k -> new ArrayList<>()).add(forkedId);
69+
if (seen.add(forkedId)) next.add(forkedId);
70+
}
71+
}
72+
frontier = next;
73+
}
74+
}
75+
76+
Map<String, Set<String>> result = new LinkedHashMap<>();
77+
for (String root : workflowIds) {
78+
if (result.containsKey(root)) continue;
79+
Set<String> descendants = new LinkedHashSet<>();
80+
Deque<String> stack = new ArrayDeque<>(children.getOrDefault(root, List.of()));
81+
while (!stack.isEmpty()) {
82+
String node = stack.pop();
83+
if (!node.equals(root) && descendants.add(node)) {
84+
stack.addAll(children.getOrDefault(node, List.of()));
85+
}
86+
}
87+
result.put(root, descendants);
88+
}
89+
return result;
90+
}
91+
92+
public static void sendBulk(
3993
DbContext ctx,
94+
List<SendMessage> messages,
4095
String workflowId,
4196
int stepId,
42-
String destinationId,
43-
Object message,
44-
String topic,
45-
String messageId,
97+
String functionName,
98+
boolean sendToForks,
4699
String serialization)
47100
throws SQLException {
48101

102+
if (messages.isEmpty()) {
103+
return;
104+
}
105+
106+
// Reject duplicate idempotency keys within the batch
107+
var keys =
108+
messages.stream()
109+
.map(SendMessage::idempotencyKey)
110+
.filter(Objects::nonNull)
111+
.toList();
112+
if (keys.size() != keys.stream().distinct().count()) {
113+
throw new IllegalArgumentException("Duplicate idempotency keys within sendBulk batch");
114+
}
115+
49116
DBOSSerializer serializer = ctx.serializer();
50117
var startTime = System.currentTimeMillis();
51-
String functionName = "DBOS.send";
52-
String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC;
118+
119+
// Serialize each message once
120+
record SerializedPair(SendMessage msg, SerializationUtil.SerializedResult serialized) {}
121+
List<SerializedPair> pairs = new ArrayList<>(messages.size());
122+
for (var msg : messages) {
123+
pairs.add(new SerializedPair(msg, SerializationUtil.serializeValue(msg.message(), serialization, serializer)));
124+
}
53125

54126
try (Connection conn = ctx.getConnection()) {
55127
conn.setAutoCommit(false);
56-
57128
try {
58-
StepResult recordedOutput =
59-
StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, stepId, functionName);
60-
61-
if (recordedOutput != null) {
62-
logger.debug(
63-
"Replaying send, id: {}, destination_uuid: {}, topic: {}",
64-
stepId,
65-
destinationId,
66-
finalTopic);
67-
conn.commit();
68-
return;
69-
} else {
70-
logger.debug(
71-
"Running send, id: {}, destination_uuid: {}, topic: {}",
72-
stepId,
73-
destinationId,
74-
finalTopic);
129+
// Check for replay if inside a workflow
130+
if (workflowId != null) {
131+
StepResult recorded =
132+
StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, stepId, functionName);
133+
if (recorded != null) {
134+
logger.debug("Replaying sendBulk, workflowId: {}, stepId: {}", workflowId, stepId);
135+
conn.commit();
136+
return;
137+
}
138+
}
139+
140+
// Collect all destination IDs for fork resolution
141+
Map<String, Set<String>> forkDescendants = Map.of();
142+
if (sendToForks) {
143+
List<String> destIds = pairs.stream().map(p -> p.msg().destinationId()).distinct().toList();
144+
forkDescendants = findForkDescendantsTxn(conn, ctx.schema(), destIds);
75145
}
76146

77-
var finalMessageId = (messageId != null) ? messageId : UUID.randomUUID().toString();
78-
var serializedMsg = SerializationUtil.serializeValue(message, serialization, serializer);
147+
// Build insert rows: base dest + sorted descendants
148+
record InsertRow(String destId, SerializationUtil.SerializedResult serialized, String topic, String messageUuid) {}
149+
List<InsertRow> rows = new ArrayList<>();
150+
for (var pair : pairs) {
151+
var msg = pair.msg();
152+
String baseDest = msg.destinationId();
153+
String finalTopic = (msg.topic() != null) ? msg.topic() : Constants.DBOS_NULL_TOPIC;
154+
155+
List<String> destinations = new ArrayList<>();
156+
destinations.add(baseDest);
157+
if (sendToForks) {
158+
var desc = forkDescendants.getOrDefault(baseDest, Set.of());
159+
desc.stream().sorted().forEach(destinations::add);
160+
}
161+
162+
for (String dest : destinations) {
163+
String uuid;
164+
if (msg.idempotencyKey() != null) {
165+
uuid = sendToForks ? msg.idempotencyKey() + "::" + dest : msg.idempotencyKey();
166+
} else {
167+
uuid = UUID.randomUUID().toString();
168+
}
169+
rows.add(new InsertRow(dest, pair.serialized(), finalTopic, uuid));
170+
}
171+
}
79172

173+
// Batch-insert all rows
80174
final String sql =
81175
"""
82176
INSERT INTO "%s".notifications
83177
(destination_uuid, topic, message, serialization, message_uuid)
84178
VALUES (?, ?, ?, ?, ?)
85179
ON CONFLICT (message_uuid) DO NOTHING
86-
"""
87-
.formatted(ctx.schema());
180+
""".formatted(ctx.schema());
88181

89182
try (PreparedStatement stmt = conn.prepareStatement(sql)) {
90-
stmt.setString(1, destinationId);
91-
stmt.setString(2, finalTopic);
92-
stmt.setString(3, serializedMsg.serializedValue());
93-
stmt.setString(4, serializedMsg.serialization());
94-
stmt.setString(5, finalMessageId);
95-
stmt.executeUpdate();
183+
for (var row : rows) {
184+
stmt.setString(1, row.destId());
185+
stmt.setString(2, row.topic());
186+
stmt.setString(3, row.serialized().serializedValue());
187+
stmt.setString(4, row.serialized().serialization());
188+
stmt.setString(5, row.messageUuid());
189+
stmt.addBatch();
190+
}
191+
stmt.executeBatch();
96192
} catch (SQLException e) {
97193
if ("23503".equals(e.getSQLState())) {
98-
throw new DBOSNonExistentWorkflowException(destinationId);
194+
// Find which destination was missing
195+
String missingDest = rows.stream().map(InsertRow::destId).findFirst().orElse("unknown");
196+
throw new DBOSNonExistentWorkflowException(missingDest);
99197
}
100198
throw e;
101199
}
102200

103-
var output = new StepResult(workflowId, stepId, functionName, null, null, null, null);
104-
StepsDAO.recordStepResult(
105-
conn, ctx.schema(), output, startTime, System.currentTimeMillis());
201+
if (workflowId != null) {
202+
var output = new StepResult(workflowId, stepId, functionName, null, null, null, null);
203+
StepsDAO.recordStepResult(conn, ctx.schema(), output, startTime, System.currentTimeMillis());
204+
}
106205

107206
conn.commit();
108-
109207
} catch (Exception e) {
110208
try {
111209
conn.rollback();
@@ -117,44 +215,6 @@ ON CONFLICT (message_uuid) DO NOTHING
117215
}
118216
}
119217

120-
public static void sendDirect(
121-
DbContext ctx,
122-
String destinationId,
123-
Object message,
124-
String topic,
125-
String messageId,
126-
String serialization)
127-
throws SQLException {
128-
DBOSSerializer serializer = ctx.serializer();
129-
String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC;
130-
String finalMessageId = (messageId != null) ? messageId : UUID.randomUUID().toString();
131-
var serializedMsg = SerializationUtil.serializeValue(message, serialization, serializer);
132-
133-
final String sql =
134-
"""
135-
INSERT INTO "%s".notifications
136-
(destination_uuid, topic, message, message_uuid, serialization)
137-
VALUES (?, ?, ?, ?, ?)
138-
ON CONFLICT (message_uuid) DO NOTHING
139-
"""
140-
.formatted(ctx.schema());
141-
142-
try (var conn = ctx.getConnection();
143-
var stmt = conn.prepareStatement(sql)) {
144-
stmt.setString(1, destinationId);
145-
stmt.setString(2, finalTopic);
146-
stmt.setString(3, serializedMsg.serializedValue());
147-
stmt.setString(4, finalMessageId);
148-
stmt.setString(5, serializedMsg.serialization());
149-
stmt.executeUpdate();
150-
} catch (SQLException e) {
151-
if ("23503".equals(e.getSQLState())) {
152-
throw new DBOSNonExistentWorkflowException(destinationId);
153-
}
154-
throw e;
155-
}
156-
}
157-
158218
public static Object recv(
159219
DbContext ctx,
160220
String workflowId,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package dev.dbos.transact.workflow;
2+
3+
import org.jspecify.annotations.NonNull;
4+
import org.jspecify.annotations.Nullable;
5+
6+
public record SendMessage(
7+
@NonNull String destinationId,
8+
@NonNull Object message,
9+
@Nullable String topic,
10+
@Nullable String idempotencyKey) {
11+
12+
public SendMessage(@NonNull String destinationId, @NonNull Object message) {
13+
this(destinationId, message, null, null);
14+
}
15+
16+
public SendMessage(@NonNull String destinationId, @NonNull Object message, @Nullable String topic) {
17+
this(destinationId, message, topic, null);
18+
}
19+
}

0 commit comments

Comments
 (0)