1010import dev .dbos .transact .json .DBOSSerializer ;
1111import dev .dbos .transact .json .SerializationUtil ;
1212import dev .dbos .transact .workflow .NotificationInfo ;
13+ import dev .dbos .transact .workflow .SendMessage ;
1314import dev .dbos .transact .workflow .internal .StepResult ;
1415
1516import java .sql .Connection ;
1819import java .sql .SQLException ;
1920import java .time .Duration ;
2021import java .time .Instant ;
22+ import java .util .ArrayDeque ;
2123import java .util .ArrayList ;
24+ import java .util .Deque ;
25+ import java .util .HashMap ;
26+ import java .util .LinkedHashMap ;
27+ import java .util .LinkedHashSet ;
2228import java .util .List ;
29+ import java .util .Map ;
2330import java .util .Objects ;
2431import java .util .Optional ;
32+ import java .util .Set ;
2533import java .util .UUID ;
2634
35+ import static java .util .stream .Collectors .joining ;
36+
2737import org .jspecify .annotations .NonNull ;
2838import org .jspecify .annotations .Nullable ;
2939import org .slf4j .Logger ;
@@ -35,77 +45,165 @@ private NotificationsDAO() {}
3545
3646 private static final Logger logger = LoggerFactory .getLogger (NotificationsDAO .class );
3747
38- public static void send (
48+ private static Map <String , Set <String >> findForkDescendantsTxn (
49+ Connection conn , String schema , List <String > workflowIds ) throws SQLException {
50+ Map <String , List <String >> children = new HashMap <>();
51+ Set <String > seen = new LinkedHashSet <>(workflowIds );
52+ List <String > frontier = new ArrayList <>(new LinkedHashSet <>(workflowIds ));
53+
54+ while (!frontier .isEmpty ()) {
55+ String placeholders = frontier .stream ().map (x -> "?" ).collect (joining ("," ));
56+ String sql =
57+ """
58+ SELECT workflow_uuid, forked_from FROM "%s".workflow_status
59+ WHERE forked_from IN (%s)
60+ """ .formatted (schema , placeholders );
61+ try (var stmt = conn .prepareStatement (sql )) {
62+ for (int i = 0 ; i < frontier .size (); i ++) stmt .setString (i + 1 , frontier .get (i ));
63+ List <String > next = new ArrayList <>();
64+ try (var rs = stmt .executeQuery ()) {
65+ while (rs .next ()) {
66+ String forkedId = rs .getString ("workflow_uuid" );
67+ String forkedFrom = rs .getString ("forked_from" );
68+ children .computeIfAbsent (forkedFrom , k -> new ArrayList <>()).add (forkedId );
69+ if (seen .add (forkedId )) next .add (forkedId );
70+ }
71+ }
72+ frontier = next ;
73+ }
74+ }
75+
76+ Map <String , Set <String >> result = new LinkedHashMap <>();
77+ for (String root : workflowIds ) {
78+ if (result .containsKey (root )) continue ;
79+ Set <String > descendants = new LinkedHashSet <>();
80+ Deque <String > stack = new ArrayDeque <>(children .getOrDefault (root , List .of ()));
81+ while (!stack .isEmpty ()) {
82+ String node = stack .pop ();
83+ if (!node .equals (root ) && descendants .add (node )) {
84+ stack .addAll (children .getOrDefault (node , List .of ()));
85+ }
86+ }
87+ result .put (root , descendants );
88+ }
89+ return result ;
90+ }
91+
92+ public static void sendBulk (
3993 DbContext ctx ,
94+ List <SendMessage > messages ,
4095 String workflowId ,
4196 int stepId ,
42- String destinationId ,
43- Object message ,
44- String topic ,
45- String messageId ,
97+ String functionName ,
98+ boolean sendToForks ,
4699 String serialization )
47100 throws SQLException {
48101
102+ if (messages .isEmpty ()) {
103+ return ;
104+ }
105+
106+ // Reject duplicate idempotency keys within the batch
107+ var keys =
108+ messages .stream ()
109+ .map (SendMessage ::idempotencyKey )
110+ .filter (Objects ::nonNull )
111+ .toList ();
112+ if (keys .size () != keys .stream ().distinct ().count ()) {
113+ throw new IllegalArgumentException ("Duplicate idempotency keys within sendBulk batch" );
114+ }
115+
49116 DBOSSerializer serializer = ctx .serializer ();
50117 var startTime = System .currentTimeMillis ();
51- String functionName = "DBOS.send" ;
52- String finalTopic = (topic != null ) ? topic : Constants .DBOS_NULL_TOPIC ;
118+
119+ // Serialize each message once
120+ record SerializedPair (SendMessage msg , SerializationUtil .SerializedResult serialized ) {}
121+ List <SerializedPair > pairs = new ArrayList <>(messages .size ());
122+ for (var msg : messages ) {
123+ pairs .add (new SerializedPair (msg , SerializationUtil .serializeValue (msg .message (), serialization , serializer )));
124+ }
53125
54126 try (Connection conn = ctx .getConnection ()) {
55127 conn .setAutoCommit (false );
56-
57128 try {
58- StepResult recordedOutput =
59- StepsDAO .checkStepResult (conn , ctx .schema (), workflowId , stepId , functionName );
60-
61- if (recordedOutput != null ) {
62- logger .debug (
63- "Replaying send, id: {}, destination_uuid: {}, topic: {}" ,
64- stepId ,
65- destinationId ,
66- finalTopic );
67- conn .commit ();
68- return ;
69- } else {
70- logger .debug (
71- "Running send, id: {}, destination_uuid: {}, topic: {}" ,
72- stepId ,
73- destinationId ,
74- finalTopic );
129+ // Check for replay if inside a workflow
130+ if (workflowId != null ) {
131+ StepResult recorded =
132+ StepsDAO .checkStepResult (conn , ctx .schema (), workflowId , stepId , functionName );
133+ if (recorded != null ) {
134+ logger .debug ("Replaying sendBulk, workflowId: {}, stepId: {}" , workflowId , stepId );
135+ conn .commit ();
136+ return ;
137+ }
138+ }
139+
140+ // Collect all destination IDs for fork resolution
141+ Map <String , Set <String >> forkDescendants = Map .of ();
142+ if (sendToForks ) {
143+ List <String > destIds = pairs .stream ().map (p -> p .msg ().destinationId ()).distinct ().toList ();
144+ forkDescendants = findForkDescendantsTxn (conn , ctx .schema (), destIds );
75145 }
76146
77- var finalMessageId = (messageId != null ) ? messageId : UUID .randomUUID ().toString ();
78- var serializedMsg = SerializationUtil .serializeValue (message , serialization , serializer );
147+ // Build insert rows: base dest + sorted descendants
148+ record InsertRow (String destId , SerializationUtil .SerializedResult serialized , String topic , String messageUuid ) {}
149+ List <InsertRow > rows = new ArrayList <>();
150+ for (var pair : pairs ) {
151+ var msg = pair .msg ();
152+ String baseDest = msg .destinationId ();
153+ String finalTopic = (msg .topic () != null ) ? msg .topic () : Constants .DBOS_NULL_TOPIC ;
154+
155+ List <String > destinations = new ArrayList <>();
156+ destinations .add (baseDest );
157+ if (sendToForks ) {
158+ var desc = forkDescendants .getOrDefault (baseDest , Set .of ());
159+ desc .stream ().sorted ().forEach (destinations ::add );
160+ }
161+
162+ for (String dest : destinations ) {
163+ String uuid ;
164+ if (msg .idempotencyKey () != null ) {
165+ uuid = sendToForks ? msg .idempotencyKey () + "::" + dest : msg .idempotencyKey ();
166+ } else {
167+ uuid = UUID .randomUUID ().toString ();
168+ }
169+ rows .add (new InsertRow (dest , pair .serialized (), finalTopic , uuid ));
170+ }
171+ }
79172
173+ // Batch-insert all rows
80174 final String sql =
81175 """
82176 INSERT INTO "%s".notifications
83177 (destination_uuid, topic, message, serialization, message_uuid)
84178 VALUES (?, ?, ?, ?, ?)
85179 ON CONFLICT (message_uuid) DO NOTHING
86- """
87- .formatted (ctx .schema ());
180+ """ .formatted (ctx .schema ());
88181
89182 try (PreparedStatement stmt = conn .prepareStatement (sql )) {
90- stmt .setString (1 , destinationId );
91- stmt .setString (2 , finalTopic );
92- stmt .setString (3 , serializedMsg .serializedValue ());
93- stmt .setString (4 , serializedMsg .serialization ());
94- stmt .setString (5 , finalMessageId );
95- stmt .executeUpdate ();
183+ for (var row : rows ) {
184+ stmt .setString (1 , row .destId ());
185+ stmt .setString (2 , row .topic ());
186+ stmt .setString (3 , row .serialized ().serializedValue ());
187+ stmt .setString (4 , row .serialized ().serialization ());
188+ stmt .setString (5 , row .messageUuid ());
189+ stmt .addBatch ();
190+ }
191+ stmt .executeBatch ();
96192 } catch (SQLException e ) {
97193 if ("23503" .equals (e .getSQLState ())) {
98- throw new DBOSNonExistentWorkflowException (destinationId );
194+ // Find which destination was missing
195+ String missingDest = rows .stream ().map (InsertRow ::destId ).findFirst ().orElse ("unknown" );
196+ throw new DBOSNonExistentWorkflowException (missingDest );
99197 }
100198 throw e ;
101199 }
102200
103- var output = new StepResult (workflowId , stepId , functionName , null , null , null , null );
104- StepsDAO .recordStepResult (
105- conn , ctx .schema (), output , startTime , System .currentTimeMillis ());
201+ if (workflowId != null ) {
202+ var output = new StepResult (workflowId , stepId , functionName , null , null , null , null );
203+ StepsDAO .recordStepResult (conn , ctx .schema (), output , startTime , System .currentTimeMillis ());
204+ }
106205
107206 conn .commit ();
108-
109207 } catch (Exception e ) {
110208 try {
111209 conn .rollback ();
@@ -117,44 +215,6 @@ ON CONFLICT (message_uuid) DO NOTHING
117215 }
118216 }
119217
120- public static void sendDirect (
121- DbContext ctx ,
122- String destinationId ,
123- Object message ,
124- String topic ,
125- String messageId ,
126- String serialization )
127- throws SQLException {
128- DBOSSerializer serializer = ctx .serializer ();
129- String finalTopic = (topic != null ) ? topic : Constants .DBOS_NULL_TOPIC ;
130- String finalMessageId = (messageId != null ) ? messageId : UUID .randomUUID ().toString ();
131- var serializedMsg = SerializationUtil .serializeValue (message , serialization , serializer );
132-
133- final String sql =
134- """
135- INSERT INTO "%s".notifications
136- (destination_uuid, topic, message, message_uuid, serialization)
137- VALUES (?, ?, ?, ?, ?)
138- ON CONFLICT (message_uuid) DO NOTHING
139- """
140- .formatted (ctx .schema ());
141-
142- try (var conn = ctx .getConnection ();
143- var stmt = conn .prepareStatement (sql )) {
144- stmt .setString (1 , destinationId );
145- stmt .setString (2 , finalTopic );
146- stmt .setString (3 , serializedMsg .serializedValue ());
147- stmt .setString (4 , finalMessageId );
148- stmt .setString (5 , serializedMsg .serialization ());
149- stmt .executeUpdate ();
150- } catch (SQLException e ) {
151- if ("23503" .equals (e .getSQLState ())) {
152- throw new DBOSNonExistentWorkflowException (destinationId );
153- }
154- throw e ;
155- }
156- }
157-
158218 public static Object recv (
159219 DbContext ctx ,
160220 String workflowId ,
0 commit comments