Skip to content

Commit 76b582f

Browse files
authored
Merge pull request #18 from dbos-inc/manoj/wfmgmt
Workflow management - cancel/resume
2 parents 55b2a70 + 0aaf499 commit 76b582f

20 files changed

Lines changed: 613 additions & 83 deletions

src/main/java/dev/dbos/transact/Constants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ public class Constants {
1212
public static final String DEFAULT_EXECUTORID = "local";
1313

1414
public static final String DBOS_NULL_TOPIC = "__null__topic__" ;
15+
16+
public static final String DBOS_INTERNAL_QUEUE = "_dbos_internal_queue" ;
1517
}

src/main/java/dev/dbos/transact/DBOS.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ public void shutdown() {
268268
}
269269
}
270270

271-
public static WorkflowHandle retrieveWorkflow(String workflowId) {
271+
public static WorkflowHandle<?> retrieveWorkflow(String workflowId) {
272272
return DBOS.getInstance().dbosExecutor.retrieveWorkflow(workflowId);
273273
}
274274

@@ -342,5 +342,28 @@ public void sleep(float seconds) {
342342

343343
this.dbosExecutor.sleep(seconds) ;
344344
}
345+
346+
/**
347+
*
348+
* Resume a workflow starting from the step after the last complete step
349+
*
350+
* @param workflowId id of the workflow
351+
* @return A handle to the workflow
352+
*/
353+
public WorkflowHandle<?> resumeWorkflow(String workflowId) {
354+
return this.dbosExecutor.resumeWorkflow(workflowId) ;
355+
}
356+
357+
/***
358+
*
359+
* Cancel the workflow. After this function is called, the next step (not the current one)
360+
* will not execute
361+
*
362+
* @param workflowId
363+
*/
364+
365+
public void cancelWorkflow(String workflowId) {
366+
this.dbosExecutor.cancelWorkflow(workflowId);
367+
}
345368
}
346369

src/main/java/dev/dbos/transact/database/StepsDAO.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ public void recordStepResultTxn(StepResult result, Connection connection) throws
6969
}
7070
}
7171

72-
7372
}
7473

7574

src/main/java/dev/dbos/transact/database/SystemDatabase.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import com.zaxxer.hikari.HikariDataSource;
55
import dev.dbos.transact.Constants;
66
import dev.dbos.transact.config.DBOSConfig;
7+
import dev.dbos.transact.context.DBOSContext;
8+
import dev.dbos.transact.context.DBOSContextHolder;
79
import dev.dbos.transact.exceptions.*;
10+
import dev.dbos.transact.json.JSONUtil;
811
import dev.dbos.transact.notifications.GetWorkflowEventContext;
912
import dev.dbos.transact.notifications.NotificationService;
1013
import dev.dbos.transact.queue.Queue;
@@ -21,6 +24,7 @@
2124
import javax.sql.DataSource;
2225
import java.sql.*;
2326
import java.util.*;
27+
import java.util.function.Supplier;
2428

2529
import static dev.dbos.transact.exceptions.ErrorCode.UNEXPECTED;
2630

@@ -302,6 +306,82 @@ public void cancelWorkflow(String workflowId) {
302306

303307
}
304308

309+
public void resumeWorkflow(String workflowId) {
310+
try {
311+
workflowDAO.resumeWorkflow(workflowId);
312+
} catch (SQLException s) {
313+
throw new DBOSException(ErrorCode.RESUME_WORKFLOW_ERROR.getCode(), s.getMessage()) ;
314+
}
315+
316+
}
317+
318+
319+
public <T> T callFunctionAsStep(Supplier<T> fn, String functionName) {
320+
DBOSContext ctx = DBOSContextHolder.get();
321+
322+
int nextFuncId = 0 ;
323+
324+
if (ctx != null && ctx.isInWorkflow()) {
325+
nextFuncId = ctx.getAndIncrementFunctionId() ;
326+
327+
StepResult result = null ;
328+
329+
try (Connection connection = dataSource.getConnection()) {
330+
result = stepsDAO.checkStepExecutionTxn(
331+
ctx.getWorkflowId(), nextFuncId, functionName, connection
332+
);
333+
} catch(SQLException e) {
334+
throw new DBOSException(UNEXPECTED.getCode(), "Function execution failed: " + functionName, e);
335+
}
336+
337+
if (result != null) {
338+
return handleExistingResult(result, functionName);
339+
}
340+
}
341+
342+
T functionResult;
343+
try {
344+
345+
try {
346+
functionResult = fn.get();
347+
} catch (Exception e) {
348+
if (ctx != null && ctx.isInWorkflow()) {
349+
String jsonError = JSONUtil.serializeError(e);
350+
StepResult r = new StepResult(ctx.getWorkflowId(), nextFuncId, functionName, null, jsonError);
351+
stepsDAO.recordStepResultTxn(r);
352+
}
353+
throw new DBOSException(UNEXPECTED.getCode(), "Function execution failed: " + functionName, e);
354+
}
355+
356+
// If we're in a workflow, record the successful result
357+
if (ctx != null && ctx.isInWorkflow()) {
358+
String jsonOutput = JSONUtil.serialize(functionResult);
359+
StepResult o = new StepResult(ctx.getWorkflowId(), nextFuncId, functionName, jsonOutput, null);
360+
stepsDAO.recordStepResultTxn(o);
361+
}
362+
} catch(SQLException sq) {
363+
throw new DBOSException(UNEXPECTED.getCode(), "Function execution failed: " + functionName, sq);
364+
}
365+
366+
return functionResult;
367+
}
368+
369+
@SuppressWarnings("unchecked")
370+
private <T> T handleExistingResult(StepResult result, String functionName) {
371+
if (result.getOutput() != null) {
372+
Object[] resArray = JSONUtil.deserializeToArray(result.getOutput());
373+
return resArray == null ? null : (T) resArray[0];
374+
} else if (result.getError() != null) {
375+
Object[] eArray = JSONUtil.deserializeToArray(result.getError());
376+
SerializableException se = (SerializableException) eArray[0];
377+
throw new DBOSAppException(String.format("Exception of type %s", se.className), se) ;
378+
} else {
379+
throw new IllegalStateException(
380+
String.format("Recorded output and error are both null for %s", functionName)
381+
);
382+
}
383+
}
384+
305385
private void createDataSource(String dbName) {
306386
HikariConfig hikariConfig = new HikariConfig();
307387

src/main/java/dev/dbos/transact/database/WorkflowDAO.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ public void recordWorkflowError(String workflowId, String error) {
358358
options.setError(error);
359359
options.setResetDeduplicationID(true);
360360

361-
362361
updateWorkflowStatus(connection, workflowId, WorkflowState.ERROR.toString(), options);
363362

364363
}
@@ -700,4 +699,74 @@ public void cancelWorkflow(String workflowId) throws SQLException {
700699
}
701700
}
702701

702+
public void resumeWorkflow(String workflowId) throws SQLException {
703+
704+
try (Connection connection = dataSource.getConnection()) {
705+
connection.setAutoCommit(false);
706+
connection.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ);
707+
708+
try {
709+
String currentStatus = getWorkflowStatus(connection, workflowId);
710+
711+
712+
if (currentStatus == null) {
713+
connection.rollback();
714+
return;
715+
}
716+
717+
// If workflow is already complete, do nothing
718+
if (WorkflowState.SUCCESS.name().equals(currentStatus) ||
719+
WorkflowState.ERROR.name().equals(currentStatus)) {
720+
connection.rollback();
721+
return;
722+
}
723+
724+
// Set the workflow's status to ENQUEUED and clear recovery fields
725+
updateWorkflowToEnqueued(connection, workflowId);
726+
727+
connection.commit();
728+
729+
} catch (SQLException e) {
730+
connection.rollback();
731+
throw e;
732+
}
733+
}
734+
}
735+
736+
private String getWorkflowStatus(Connection connection, String workflowId) throws SQLException {
737+
String sql = "SELECT status FROM dbos.workflow_status WHERE workflow_uuid = ?";
738+
739+
try (PreparedStatement stmt = connection.prepareStatement(sql)) {
740+
stmt.setString(1, workflowId);
741+
742+
try (ResultSet rs = stmt.executeQuery()) {
743+
if (rs.next()) {
744+
return rs.getString("status");
745+
}
746+
return null;
747+
}
748+
}
749+
}
750+
751+
private void updateWorkflowToEnqueued(Connection connection, String workflowId) throws SQLException {
752+
String sql = "UPDATE dbos.workflow_status " +
753+
" SET status = ?, " +
754+
" queue_name = ?, " +
755+
" recovery_attempts = ?, " +
756+
" workflow_deadline_epoch_ms = 0, " +
757+
" deduplication_id = NULL, " +
758+
" started_at_epoch_ms = NULL " +
759+
" WHERE workflow_uuid = ? " ;
760+
761+
762+
try (PreparedStatement stmt = connection.prepareStatement(sql)) {
763+
stmt.setString(1, WorkflowState.ENQUEUED.name());
764+
stmt.setString(2, Constants.DBOS_INTERNAL_QUEUE);
765+
stmt.setInt(3, 0); // recovery_attempts = 0
766+
stmt.setString(4, workflowId);
767+
768+
stmt.executeUpdate();
769+
}
770+
}
771+
703772
}

src/main/java/dev/dbos/transact/exceptions/ErrorCode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public enum ErrorCode {
1111
WORKFLOW_CANCELLED(7),
1212
UNEXPECTED_STEP(8),
1313
WORKFLOW_FUNCTION_NOT_FOUND(9),
14-
SLEEP_NOT_IN_WORKFLOW(10);
14+
SLEEP_NOT_IN_WORKFLOW(10),
15+
RESUME_WORKFLOW_ERROR(11);
1516

1617
private int code ;
1718

src/main/java/dev/dbos/transact/execution/DBOSExecutor.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424

2525
import java.lang.reflect.InvocationTargetException;
2626
import java.lang.reflect.Method;
27+
import java.sql.SQLException;
2728
import java.util.UUID;
2829
import java.util.concurrent.*;
30+
import java.util.function.Supplier;
2931

3032
import static dev.dbos.transact.exceptions.ErrorCode.UNEXPECTED;
3133

@@ -363,8 +365,9 @@ public <T> T runStep(String stepName,
363365

364366
String output = recordedResult.getOutput() ;
365367
if (output != null) {
368+
logger.info("Result has an output") ;
366369
Object[] stepO = JSONUtil.deserializeToArray(output);
367-
return(T) stepO[0];
370+
return stepO == null ? null : (T) stepO[0];
368371
}
369372

370373
String error = recordedResult.getError();
@@ -414,7 +417,7 @@ public <T> T runStep(String stepName,
414417
* Retrieve the workflowHandle for the workflowId
415418
*
416419
*/
417-
public WorkflowHandle retrieveWorkflow(String workflowId) {
420+
public WorkflowHandle<?> retrieveWorkflow(String workflowId) {
418421
return new WorkflowHandleDBPoll(workflowId, systemDatabase) ;
419422
}
420423

@@ -467,4 +470,28 @@ public void sleep(float seconds) {
467470

468471
}
469472

473+
public WorkflowHandle<?> resumeWorkflow(String workflowId) {
474+
475+
Supplier<Void> resumeFunction = () -> {
476+
logger.info("Resuming workflow: ", workflowId);
477+
systemDatabase.resumeWorkflow(workflowId);
478+
return null ; // void
479+
};
480+
// Execute the resume operation as a workflow step
481+
systemDatabase.callFunctionAsStep(resumeFunction, "DBOS.resumeWorkflow");
482+
return retrieveWorkflow(workflowId);
483+
}
484+
485+
public void cancelWorkflow(String workflowId) {
486+
487+
Supplier<Void> cancelFunction = () -> {
488+
logger.info("Cancelling workflow: ", workflowId);
489+
systemDatabase.cancelWorkflow(workflowId);
490+
return null ; // void
491+
};
492+
// Execute the cancel operation as a workflow step
493+
systemDatabase.callFunctionAsStep(cancelFunction, "DBOS.resumeWorkflow");
494+
495+
}
496+
470497
}

src/main/java/dev/dbos/transact/json/JSONUtil.java

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,14 @@
1010
import com.fasterxml.jackson.databind.jsontype.BasicPolymorphicTypeValidator;
1111
import com.fasterxml.jackson.databind.jsontype.PolymorphicTypeValidator;
1212
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
13+
import dev.dbos.transact.exceptions.SerializableException;
1314

1415
import java.lang.reflect.Type;
1516

1617
public class JSONUtil {
1718

1819
private static final ObjectMapper mapper = new ObjectMapper();
1920

20-
/* static {
21-
// mapper.activateDefaultTyping(mapper.getPolymorphicTypeValidator(), ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
22-
PolymorphicTypeValidator ptv = BasicPolymorphicTypeValidator.builder()
23-
.allowIfSubType(Object.class)
24-
.build();
25-
26-
// This ensures type info is added even for final classes like String, Integer, etc.
27-
mapper.activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.EVERYTHING, JsonTypeInfo.As.PROPERTY);
28-
29-
} */
30-
3121
static {
3222
mapper.registerModule(new JavaTimeModule());
3323
mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); // Optional
@@ -54,22 +44,19 @@ public static Object[] deserializeToArray(String json) {
5444
}
5545
}
5646

47+
public static String serializeError(Throwable error) {
5748

58-
public static <T> T deserialize(String json, Class<T> clazz) {
59-
try {
60-
return mapper.readValue(json, clazz);
61-
} catch (JsonProcessingException e) {
62-
throw new RuntimeException("Deserialization failed", e);
63-
}
49+
SerializableException se = new SerializableException(error);
50+
return JSONUtil.serialize(se) ;
6451
}
6552

66-
public static <T> T deserialize(String json, TypeReference<T> typeRef) {
67-
try {
68-
return mapper.readValue(json, typeRef);
69-
} catch (JsonProcessingException e) {
70-
throw new RuntimeException("Deserialization failed", e);
71-
}
53+
public static SerializableException deserializeError(String json) {
54+
Object[] eArray = JSONUtil.deserializeToArray(json);
55+
return (SerializableException) eArray[0];
7256
}
7357

58+
59+
60+
7461
}
7562

src/main/java/dev/dbos/transact/queue/QueueService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dev.dbos.transact.queue;
22

33
import dev.dbos.transact.Constants;
4+
import dev.dbos.transact.DBOS;
45
import dev.dbos.transact.database.SystemDatabase;
56
import dev.dbos.transact.execution.DBOSExecutor;
67
import org.slf4j.Logger;
@@ -24,6 +25,8 @@ public class QueueService {
2425
private QueueRegistry queueRegistry ;
2526
private CountDownLatch shutdownLatch;
2627

28+
private Queue internalQueue ;
29+
2730
public QueueService(SystemDatabase systemDatabase) {
2831
this.systemDatabase = systemDatabase ;
2932
queueRegistry = new QueueRegistry();
@@ -41,6 +44,8 @@ public void register(Queue queue) {
4144
private void pollForWorkflows() {
4245
logger.info("PollQueuesThread started ...." + Thread.currentThread().getId()) ;
4346

47+
internalQueue = new DBOS.QueueBuilder(Constants.DBOS_INTERNAL_QUEUE).build() ;
48+
4449
double pollingInterval = 1.0 ;
4550
double minPollingInterval = 1.0 ;
4651
double maxPollingInterval = 120.0 ;

0 commit comments

Comments
 (0)