diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 714de1f64..d3ffff5fb 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -14,5 +14,7 @@ on: jobs: test: uses: ./.github/workflows/test.yml + test-crdb: + uses: ./.github/workflows/test_crdb.yml test-demo-apps: uses: ./.github/workflows/test_demo_apps.yml \ No newline at end of file diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index c6d155cab..668fa3c9b 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -9,6 +9,8 @@ on: jobs: test: uses: ./.github/workflows/test.yml + test-crdb: + uses: ./.github/workflows/test_crdb.yml publish: needs: test uses: ./.github/workflows/publish.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 99f1e7647..601464c88 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,7 +39,6 @@ jobs: run: ./gradlew clean build env: PGPASSWORD: dbos - JDKVERSION: ${{ matrix.jdk-version }} SCALE_TEST: "true" - name: Test Summary diff --git a/.github/workflows/test_crdb.yml b/.github/workflows/test_crdb.yml new file mode 100644 index 000000000..42aa46f0b --- /dev/null +++ b/.github/workflows/test_crdb.yml @@ -0,0 +1,46 @@ +name: Test (CockroachDB) + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 # fetch-depth 0 needed for version calculation + + - name: Set up JDK temurin 25 + uses: actions/setup-java@v5 + with: + java-version: '25' + distribution: 'temurin' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 + + - name: Run tests + run: ./gradlew clean build + env: + PGPASSWORD: dbos + DBOS_TEST_USE_COCKROACH_DB: 'true' + + - name: Test Summary + uses: test-summary/action@v2 + with: + paths: "transact/build/test-results/test/TEST-*.xml" + show: "fail, skip" + if: always() + + - name: Upload test results + uses: actions/upload-artifact@v7 + if: always() + with: + name: test-results-crdb-temurin-25 + path: | + transact/build/reports/tests/ + transact/build/test-results/ diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index f9f4b1678..95d4cf67a 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -51,6 +51,7 @@ jspecify = { module = "org.jspecify:jspecify", version.ref = "jspecify" } junit-bom = { module = "org.junit:junit-bom", version.ref = "junit" } junit-jupiter = { module = "org.junit.jupiter:junit-jupiter" } junit-pioneer = { module = "org.junit-pioneer:junit-pioneer", version.ref = "junit-pioneer" } +junit-platform-engine = { module = "org.junit.platform:junit-platform-engine" } junit-platform-launcher = { module = "org.junit.platform:junit-platform-launcher" } kryo = { module = "com.esotericsoftware:kryo", version.ref = "kryo" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } @@ -70,6 +71,7 @@ spring-boot4-dependencies = { module = "org.springframework.boot:spring-boot-dep spring-boot4-test = { module = "org.springframework.boot:spring-boot-test", version.ref = "spring-boot4" } sqlite-jdbc = { module = "org.xerial:sqlite-jdbc", version.ref = "sqlite-jdbc" } system-stubs-jupiter = { module = "uk.org.webcompere:system-stubs-jupiter", version.ref = "system-stubs" } +testcontainers-cockroachdb = { module = "org.testcontainers:testcontainers-cockroachdb", version.ref = "testcontainers" } testcontainers-postgresql = { module = "org.testcontainers:testcontainers-postgresql", version.ref = "testcontainers" } [bundles] diff --git a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java index 995b8a998..6a00da13d 100644 --- a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java +++ b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java @@ -38,8 +38,9 @@ public Integer call() throws Exception { out.format(" System Database: %s\n", dbOptions.url()); out.format(" System Database User: %s\n", dbOptions.user()); + // TODO: add option for useListenNotify MigrationManager.runMigrations( - dbOptions.url(), dbOptions.user(), dbOptions.password(), dbOptions.schema()); + dbOptions.url(), dbOptions.user(), dbOptions.password(), dbOptions.schema(), true); grantDBOSSchemaPermissions(out, dbOptions.schema()); return 0; } diff --git a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java index 068de6383..935b516cb 100644 --- a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java +++ b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java @@ -142,6 +142,7 @@ private DBOSConfig buildConfig(DBOSProperties props, String springAppName) { config = config.withAdminServer(props.getAdminServer().isEnabled()); config = config.withAdminServerPort(props.getAdminServer().getPort()); config = config.withMigrate(props.getDatasource().isMigrate()); + config = config.withUseListenNotify(props.getDatasource().isUseListenNotify()); config = config.withEnablePatching(props.isEnablePatching()); List listenQueues = props.getListenQueues(); diff --git a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java index 1d2baa2bf..e24e32b28 100644 --- a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java +++ b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java @@ -133,6 +133,12 @@ public static class Datasource { /** Whether to run database migrations on startup. Defaults to {@code true}. */ private boolean migrate = true; + /** + * Whether to use PostgreSQL LISTEN/NOTIFY for event delivery. Defaults to {@code true}. Set to + * {@code false} to use polling instead + */ + private boolean useListenNotify = true; + public String getUrl() { return url; } @@ -172,6 +178,14 @@ public boolean isMigrate() { public void setMigrate(boolean migrate) { this.migrate = migrate; } + + public boolean isUseListenNotify() { + return useListenNotify; + } + + public void setUseListenNotify(boolean useListenNotify) { + this.useListenNotify = useListenNotify; + } } public Application getApplication() { diff --git a/transact/build.gradle.kts b/transact/build.gradle.kts index 27bc382ca..c50162935 100644 --- a/transact/build.gradle.kts +++ b/transact/build.gradle.kts @@ -39,6 +39,7 @@ dependencies { testImplementation(libs.junit.jupiter) testImplementation(libs.junit.pioneer) testImplementation(libs.system.stubs.jupiter) + testImplementation(libs.junit.platform.engine) testRuntimeOnly(libs.junit.platform.launcher) testImplementation(libs.java.websocket) @@ -48,6 +49,7 @@ dependencies { testImplementation(libs.rest.assured) testImplementation(libs.kryo) testImplementation(libs.maven.artifact) + testImplementation(libs.testcontainers.cockroachdb) testImplementation(libs.testcontainers.postgresql) } diff --git a/transact/src/main/java/dev/dbos/transact/DBOSClient.java b/transact/src/main/java/dev/dbos/transact/DBOSClient.java index 76092069e..01fdfed2e 100644 --- a/transact/src/main/java/dev/dbos/transact/DBOSClient.java +++ b/transact/src/main/java/dev/dbos/transact/DBOSClient.java @@ -79,7 +79,7 @@ public T getResult() throws E { * @param password System database credential / password */ public DBOSClient(@NonNull String url, @NonNull String user, @NonNull String password) { - this(url, user, password, null, null); + this(url, user, password, null, null, true); } /** @@ -95,7 +95,7 @@ public DBOSClient( @NonNull String user, @NonNull String password, @Nullable String schema) { - this(url, user, password, schema, null); + this(url, user, password, schema, null, true); } /** @@ -113,8 +113,18 @@ public DBOSClient( @NonNull String password, @Nullable String schema, @Nullable DBOSSerializer serializer) { + this(url, user, password, schema, serializer, true); + } + + public DBOSClient( + @NonNull String url, + @NonNull String user, + @NonNull String password, + @Nullable String schema, + @Nullable DBOSSerializer serializer, + boolean useListenNotify) { this.serializer = serializer; - systemDatabase = new SystemDatabase(url, user, password, schema, serializer); + systemDatabase = new SystemDatabase(url, user, password, schema, serializer, useListenNotify); } /** diff --git a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java index 79b06be59..97f5200ff 100644 --- a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java +++ b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java @@ -35,7 +35,8 @@ public record DBOSConfig( boolean enablePatching, @NonNull Set listenQueues, @Nullable DBOSSerializer serializer, - @Nullable Duration schedulerPollingInterval) { + @Nullable Duration schedulerPollingInterval, + boolean useListenNotify) { public DBOSConfig { if (appName == null || appName.isEmpty()) { @@ -80,7 +81,8 @@ public DBOSConfig(DBOSConfig other) { other.enablePatching, (other.listenQueues == null ? null : Set.copyOf(other.listenQueues)), other.serializer, - other.schedulerPollingInterval); + other.schedulerPollingInterval, + other.useListenNotify); } public static @NonNull DBOSConfig defaults(@NonNull String appName) { @@ -88,7 +90,7 @@ public DBOSConfig(DBOSConfig other) { appName, null, null, null, null, false, // adminServer 3001, // adminServerPort true, // migrate - null, null, null, null, null, null, false, null, null, null); + null, null, null, null, null, null, false, null, null, null, true); // useListenNotify } public static @NonNull DBOSConfig defaultsFromEnv(@NonNull String appName) { @@ -121,7 +123,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDatabaseUrl(@Nullable String v) { @@ -143,7 +146,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDbUser(@Nullable String v) { @@ -165,7 +169,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDbPassword(@Nullable String v) { @@ -187,7 +192,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDataSource(@Nullable DataSource v) { @@ -209,7 +215,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAdminServer(boolean v) { @@ -231,7 +238,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAdminServerPort(int v) { @@ -253,7 +261,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withMigrate(boolean v) { @@ -275,7 +284,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorKey(@Nullable String v) { @@ -297,7 +307,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorDomain(@Nullable String v) { @@ -319,7 +330,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorExecutorMetadata(@Nullable Map v) { @@ -341,7 +353,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAppVersion(@Nullable String v) { @@ -363,7 +376,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withExecutorId(@Nullable String v) { @@ -385,7 +399,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDatabaseSchema(@Nullable String v) { @@ -407,7 +422,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withEnablePatching() { @@ -437,7 +453,8 @@ public DBOSConfig(DBOSConfig other) { v, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig enableAdminServer() { @@ -489,7 +506,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, v, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withSerializer(@Nullable DBOSSerializer v) { @@ -511,7 +529,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, v, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withSchedulerPollingInterval(@Nullable Duration v) { @@ -533,6 +552,30 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, + v, + useListenNotify); + } + + public @NonNull DBOSConfig withUseListenNotify(boolean v) { + return new DBOSConfig( + appName, + databaseUrl, + dbUser, + dbPassword, + dataSource, + adminServer, + adminServerPort, + migrate, + conductorKey, + conductorDomain, + conductorExecutorMetadata, + appVersion, + executorId, + databaseSchema, + enablePatching, + listenQueues, + serializer, + schedulerPollingInterval, v); } @@ -544,7 +587,7 @@ public String toString() { dataSource=%s, databaseSchema=%s, adminServer=%s, adminServerPort=%d, \ migrate=%s, conductorKey=%s, conductorDomain=%s, \ conductorExecutorMetadata=%s, appVersion=%s, executorId=%s, enablePatching=%s, \ - listenQueues=%s, serializer=%s, schedulerPollingInterval=%s] + listenQueues=%s, serializer=%s, schedulerPollingInterval=%s, useListenNotify=%s] """ .formatted( appName, @@ -563,6 +606,7 @@ public String toString() { enablePatching, listenQueues, serializer != null ? serializer.name() : null, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } } diff --git a/transact/src/main/java/dev/dbos/transact/database/DbContext.java b/transact/src/main/java/dev/dbos/transact/database/DbContext.java new file mode 100644 index 000000000..d92dbfce0 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/DbContext.java @@ -0,0 +1,23 @@ +package dev.dbos.transact.database; + +import dev.dbos.transact.json.DBOSSerializer; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.BooleanSupplier; + +import javax.sql.DataSource; + +public record DbContext( + DataSource dataSource, String schema, DBOSSerializer serializer, BooleanSupplier closed) { + + public Connection getConnection() throws SQLException { + return dataSource.getConnection(); + } + + public void checkClosed() { + if (closed.getAsBoolean()) { + throw new IllegalStateException("Database is closed"); + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java b/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java new file mode 100644 index 000000000..de7f514c3 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java @@ -0,0 +1,3 @@ +package dev.dbos.transact.database; + +public record GetEventCaller(String workflowId, int stepId, int timeoutStepId) {} diff --git a/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java b/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java deleted file mode 100644 index e183b4627..000000000 --- a/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java +++ /dev/null @@ -1,3 +0,0 @@ -package dev.dbos.transact.database; - -public record GetWorkflowEventContext(String workflowId, int functionId, int timeoutFunctionId) {} diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationService.java b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java similarity index 65% rename from transact/src/main/java/dev/dbos/transact/database/NotificationService.java rename to transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java index 2681b6963..87e2f8c90 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationService.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java @@ -1,13 +1,14 @@ package dev.dbos.transact.database; +import dev.dbos.transact.database.SystemDatabase.NotificationSource; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalMap; +import dev.dbos.transact.database.signal.Subscription; + import java.sql.Connection; import java.sql.SQLException; import java.sql.Statement; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.ReentrantLock; import javax.sql.DataSource; @@ -16,35 +17,31 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class NotificationService { - - public static class LockConditionPair { - public final ReentrantLock lock = new ReentrantLock(); - public final Condition condition = lock.newCondition(); - } +class NotificationListenerSource implements NotificationSource { - private static final Logger logger = LoggerFactory.getLogger(NotificationService.class); + private static final Logger logger = LoggerFactory.getLogger(NotificationListenerSource.class); - private final Map notificationsMap = new ConcurrentHashMap<>(); - private final AtomicReference notificationListenerThread = new AtomicReference<>(null); private final DataSource dataSource; + private final AtomicReference notificationListenerThread = new AtomicReference<>(null); + private final SignalMap signalMap = new SignalMap<>(); - public NotificationService(DataSource dataSource) { + public NotificationListenerSource(DataSource dataSource) { this.dataSource = dataSource; } - public boolean registerNotificationCondition(String key, LockConditionPair pair) { - return notificationsMap.putIfAbsent(key, pair) == null; - } - - public LockConditionPair getOrCreateNotificationCondition(String key) { - return notificationsMap.computeIfAbsent(key, k -> new LockConditionPair()); + @Override + public Subscription subscribe(SignalKey.Message key) { + var strKey = "m::%s::%s".formatted(key.workflowId(), key.topic()); + return signalMap.subscribe(strKey, key.wakeReason()); } - public void unregisterNotificationCondition(String key) { - notificationsMap.remove(key); + @Override + public Subscription subscribe(SignalKey.Event key) { + var strKey = "e::%s::%s".formatted(key.workflowId(), key.key()); + return signalMap.subscribe(strKey, key.wakeReason()); } + @Override public void start() { Thread t = new Thread(this::notificationListener, "NotificationListener"); t.setDaemon(true); @@ -54,7 +51,8 @@ public void start() { } } - public void stop() { + @Override + public void close() { Thread t = notificationListenerThread.getAndSet(null); if (t != null) { t.interrupt(); @@ -65,7 +63,6 @@ public void stop() { } } - notificationsMap.clear(); logger.debug("Notification listener stopped"); } @@ -103,9 +100,8 @@ private void notificationListener() { logger.error("Received notification with null channel. Payload: {}", payload); } else switch (channel) { - case "dbos_notifications_channel" -> handleNotification(payload, "notifications"); - case "dbos_workflow_events_channel" -> - handleNotification(payload, "workflow_events"); + case "dbos_notifications_channel" -> signalMap.signal("m::" + payload); + case "dbos_workflow_events_channel" -> signalMap.signal("e::" + payload); default -> logger.error("Unknown NOTIFY channel: {}", channel); } } @@ -136,25 +132,4 @@ private void notificationListener() { } logger.debug("Notification listener thread exiting"); } - - private void handleNotification(String payload, String mapType) { - - logger.debug("Received notification for {}", payload); - - if (payload != null && !payload.isEmpty()) { - LockConditionPair pair = notificationsMap.get(payload); - if (pair != null) { - pair.lock.lock(); - try { - pair.condition.signalAll(); - } finally { - pair.lock.unlock(); - } - logger.debug("Signaled {} condition for {}", mapType, payload); - } else { - logger.debug("ConditionMap has no entry for {}", payload); - } - // If no condition found, we simply ignore the notification - } - } } diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index e9e8df996..dfa0483b7 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -2,6 +2,18 @@ import dev.dbos.transact.Constants; import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.dao.ApplicationVersionDAO; +import dev.dbos.transact.database.dao.ExternalStateDAO; +import dev.dbos.transact.database.dao.NotificationsDAO; +import dev.dbos.transact.database.dao.QueuesDAO; +import dev.dbos.transact.database.dao.SchedulesDAO; +import dev.dbos.transact.database.dao.StepsDAO; +import dev.dbos.transact.database.dao.StreamsDAO; +import dev.dbos.transact.database.dao.WorkflowDAO; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalKey.Event; +import dev.dbos.transact.database.signal.SignalKey.Message; +import dev.dbos.transact.database.signal.Subscription; import dev.dbos.transact.exceptions.*; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.workflow.ExportedWorkflow; @@ -24,6 +36,7 @@ import java.time.Duration; import java.time.Instant; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import javax.sql.DataSource; @@ -34,18 +47,48 @@ public class SystemDatabase implements AutoCloseable { + public interface NotificationRegistry { + Subscription subscribe(SignalKey.Message key); + + Subscription subscribe(SignalKey.Event key); + } + + public interface NotificationSource extends NotificationRegistry { + void start(); + + void close(); + } + + class NullNotificationSource implements NotificationSource { + + @Override + public Subscription subscribe(Message key) { + return new Subscription(() -> {}); + } + + @Override + public Subscription subscribe(Event key) { + return new Subscription(() -> {}); + } + + @Override + public void start() {} + + @Override + public void close() {} + } + private static final Logger logger = LoggerFactory.getLogger(SystemDatabase.class); public static String sanitizeSchema(String schema) { return Objects.requireNonNullElse(schema, Constants.DB_SCHEMA).replace("\0", ""); } - private final DataSource dataSource; - private final String schema; + private final DbContext ctx; private final boolean created; - private final DBOSSerializer serializer; - private final NotificationService notificationService; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final NotificationSource notificationSource; private Duration dbPollingInterval = Duration.ofSeconds(1); private static void validatePostgresDataSource(DataSource dataSource) { @@ -62,36 +105,50 @@ private static void validatePostgresDataSource(DataSource dataSource) { } private SystemDatabase( - DataSource dataSource, String schema, boolean created, DBOSSerializer serializer) { + DataSource dataSource, + String schema, + boolean created, + DBOSSerializer serializer, + boolean useListenNotify) { validatePostgresDataSource(dataSource); schema = sanitizeSchema(schema); if (schema.contains("\"")) { throw new IllegalArgumentException("Schema name must not contain double quotes"); } - this.schema = schema; - this.dataSource = dataSource; + this.ctx = new DbContext(dataSource, schema, serializer, this.closed::get); this.created = created; - this.serializer = serializer; + try { + useListenNotify = isCockroach(dataSource) ? false : useListenNotify; + } catch (SQLException e) { + logger.error("Failed to determine if dataSource is CockroachDB", e); + useListenNotify = false; + } - notificationService = new NotificationService(dataSource); + notificationSource = + useListenNotify ? new NotificationListenerSource(dataSource) : new NullNotificationSource(); } - public SystemDatabase(String url, String user, String password, String schema) { - this(createDataSource(url, user, password), schema, true, null); + public SystemDatabase( + String url, + String user, + String password, + String schema, + DBOSSerializer serializer, + boolean useListenNotify) { + this(createDataSource(url, user, password), schema, true, serializer, useListenNotify); } - public SystemDatabase( - String url, String user, String password, String schema, DBOSSerializer serializer) { - this(createDataSource(url, user, password), schema, true, serializer); + public SystemDatabase(String url, String user, String password, String schema) { + this(createDataSource(url, user, password), schema, true, null, true); } public SystemDatabase(DataSource dataSource, String schema) { - this(dataSource, schema, false, null); + this(dataSource, schema, false, null, true); } public SystemDatabase(DataSource dataSource, String schema, DBOSSerializer serializer) { - this(dataSource, schema, false, serializer); + this(dataSource, schema, false, serializer, true); } public static SystemDatabase create(DBOSConfig config) { @@ -101,14 +158,15 @@ public static SystemDatabase create(DBOSConfig config) { config.dbUser(), config.dbPassword(), config.databaseSchema(), - config.serializer()); + config.serializer(), + config.useListenNotify()); } else { return new SystemDatabase(config.dataSource(), config.databaseSchema(), config.serializer()); } } Optional getConfig() { - if (dataSource instanceof HikariDataSource hds) { + if (ctx.dataSource() instanceof HikariDataSource hds) { return Optional.of(hds); } return Optional.empty(); @@ -140,16 +198,33 @@ public static HikariDataSource createDataSource(String url, String user, String return new HikariDataSource(config); } + public static boolean isCockroach(DataSource dataSource) throws SQLException { + try (var conn = dataSource.getConnection()) { + return isCockroach(conn); + } + } + + public static boolean isCockroach(Connection conn) throws SQLException { + try (var stmt = conn.createStatement(); + var rs = stmt.executeQuery("SELECT version()")) { + if (rs.next()) { + return rs.getString(1).toLowerCase().contains("cockroachdb"); + } + } + return false; + } + @Override public void close() { - notificationService.stop(); - if (created && dataSource instanceof HikariDataSource hikariDataSource) { + closed.set(true); + notificationSource.close(); + if (created && ctx.dataSource() instanceof HikariDataSource hikariDataSource) { hikariDataSource.close(); } } public void start() { - notificationService.start(); + notificationSource.start(); } void speedUpPollingForTest() { @@ -198,6 +273,9 @@ private T dbRetry(SqlSupplier supplier) { final int MAX_RETRIES = 20; int attempt = 0; while (true) { + if (closed.get()) { + throw new IllegalStateException("SystemDatabase is closed"); + } try { return supplier.get(); } catch (SQLException e) { @@ -207,7 +285,7 @@ private T dbRetry(SqlSupplier supplier) { } if (e instanceof SQLRecoverableException || isConnectionFailure(e)) { logger.warn("Recoverable connection error. Resetting client pool.", e); - if (dataSource instanceof HikariDataSource hikariDataSource) { + if (ctx.dataSource() instanceof HikariDataSource hikariDataSource) { hikariDataSource.getHikariPoolMXBean().softEvictConnections(); } waitForRecovery(attempt, 2000); @@ -221,11 +299,11 @@ private T dbRetry(SqlSupplier supplier) { } } - static Instant toInstant(Long epochMs) { + public static Instant toInstant(Long epochMs) { return epochMs != null ? Instant.ofEpochMilli(epochMs) : null; } - static Duration toDuration(Long ms) { + public static Duration toDuration(Long ms) { return ms != null ? Duration.ofMillis(ms) : null; } @@ -256,14 +334,7 @@ public WorkflowInitResult initWorkflowStatus( return dbRetry( () -> WorkflowDAO.initWorkflowStatus( - dataSource, - schema, - serializer, - initStatus, - maxRetries, - isRecoveryRequest, - isDequeuedRequest, - ownerXid)); + ctx, initStatus, maxRetries, isRecoveryRequest, isDequeuedRequest, ownerXid)); } /** @@ -273,7 +344,7 @@ public WorkflowInitResult initWorkflowStatus( * @param result output serialized as json */ public void recordWorkflowOutput(String workflowId, String result) { - dbRetry(() -> WorkflowDAO.recordWorkflowOutput(dataSource, schema, workflowId, result)); + dbRetry(() -> WorkflowDAO.recordWorkflowOutput(ctx, workflowId, result)); } /** @@ -283,77 +354,67 @@ public void recordWorkflowOutput(String workflowId, String result) { * @param error output serialized as json */ public void recordWorkflowError(String workflowId, String error) { - dbRetry(() -> WorkflowDAO.recordWorkflowError(dataSource, schema, workflowId, error)); + dbRetry(() -> WorkflowDAO.recordWorkflowError(ctx, workflowId, error)); } public WorkflowStatus getWorkflowStatus(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowStatus(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowStatus(ctx, workflowId)); } public String getWorkflowSerialization(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowSerialization(dataSource, schema, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowSerialization(ctx, workflowId)); } public List listWorkflows(ListWorkflowsInput input) { - return dbRetry(() -> WorkflowDAO.listWorkflows(dataSource, schema, serializer, input)); + return dbRetry(() -> WorkflowDAO.listWorkflows(ctx, input)); } public List getWorkflowAggregates(GetWorkflowAggregatesInput input) { - return dbRetry(() -> WorkflowDAO.getWorkflowAggregates(dataSource, schema, serializer, input)); + return dbRetry(() -> WorkflowDAO.getWorkflowAggregates(ctx, input)); } public List getPendingWorkflows(List executorIds, String appVersion) { - return dbRetry( - () -> - WorkflowDAO.getPendingWorkflows( - dataSource, schema, serializer, executorIds, appVersion)); + return dbRetry(() -> WorkflowDAO.getPendingWorkflows(ctx, executorIds, appVersion)); } public boolean clearQueueAssignment(String workflowId) { - return dbRetry(() -> QueuesDAO.clearQueueAssignment(dataSource, schema, workflowId)); + return dbRetry(() -> QueuesDAO.clearQueueAssignment(ctx, workflowId)); } public List getQueuePartitions(String queueName) { - return dbRetry(() -> QueuesDAO.getQueuePartitions(dataSource, schema, queueName)); + return dbRetry(() -> QueuesDAO.getQueuePartitions(ctx, queueName)); } - public StepResult checkStepExecutionTxn(String workflowId, int functionId, String functionName) { + public StepResult checkStepResult(String workflowId, int functionId, String functionName) { return dbRetry( () -> { - try (Connection connection = dataSource.getConnection()) { - return StepsDAO.checkStepExecutionTxn( - workflowId, functionId, functionName, connection, this.schema); + try (Connection connection = ctx.getConnection()) { + return StepsDAO.checkStepResult( + connection, ctx.schema(), workflowId, functionId, functionName); } }); } - public void recordStepResultTxn(StepResult result, long startTime) { + public void recordStepResult(StepResult result, long startTime) { var et = System.currentTimeMillis(); - dbRetry(() -> StepsDAO.recordStepResultTxn(dataSource, result, startTime, et, this.schema)); + dbRetry(() -> StepsDAO.recordStepResult(ctx, result, startTime, et)); } public List listWorkflowSteps( String workflowId, Boolean loadOutput, Integer limit, Integer offset) { - return dbRetry( - () -> - StepsDAO.listWorkflowSteps( - dataSource, workflowId, loadOutput, limit, offset, this.schema, this.serializer)); + return dbRetry(() -> StepsDAO.listWorkflowSteps(ctx, workflowId, loadOutput, limit, offset)); } public Result awaitWorkflowResult(String workflowId) { - return dbRetry( - () -> - WorkflowDAO.awaitWorkflowResult( - dataSource, schema, serializer, dbPollingInterval, workflowId)); + return dbRetry(() -> WorkflowDAO.awaitWorkflowResult(ctx, dbPollingInterval, workflowId)); } public List getAndStartQueuedWorkflows( Queue queue, String executorId, String appVersion, String partitionKey) { return dbRetry( () -> - QueuesDAO.getAndStartQueuedWorkflows( - dataSource, schema, queue, executorId, appVersion, partitionKey)); + QueuesDAO.getAndStartQueuedWorkflows(ctx, queue, executorId, appVersion, partitionKey)); } public void recordChildWorkflow( @@ -365,12 +426,11 @@ public void recordChildWorkflow( dbRetry( () -> WorkflowDAO.recordChildWorkflow( - dataSource, schema, parentId, childId, functionId, functionName, startTime)); + ctx, parentId, childId, functionId, functionName, startTime)); } public Optional checkChildWorkflow(String workflowUuid, int functionId) { - return dbRetry( - () -> WorkflowDAO.checkChildWorkflow(dataSource, schema, workflowUuid, functionId)); + return dbRetry(() -> WorkflowDAO.checkChildWorkflow(ctx, workflowUuid, functionId)); } public void send( @@ -384,16 +444,7 @@ public void send( dbRetry( () -> NotificationsDAO.send( - dataSource, - schema, - serializer, - workflowId, - stepId, - destinationId, - message, - topic, - messageId, - serialization)); + ctx, workflowId, stepId, destinationId, message, topic, messageId, serialization)); } public void sendDirect( @@ -401,14 +452,7 @@ public void sendDirect( dbRetry( () -> NotificationsDAO.sendDirect( - dataSource, - schema, - serializer, - destinationId, - message, - topic, - messageId, - serialization)); + ctx, destinationId, message, topic, messageId, serialization)); } public Object recv( @@ -416,16 +460,14 @@ public Object recv( return dbRetry( () -> NotificationsDAO.recv( - dataSource, - schema, - serializer, - notificationService, - dbPollingInterval, + ctx, workflowId, stepId, + timeout, timeoutStepId, topic, - timeout)); + dbPollingInterval, + notificationSource)); } public void setEvent( @@ -439,94 +481,74 @@ public void setEvent( dbRetry( () -> NotificationsDAO.setEvent( - dataSource, - schema, - serializer, - workflowId, - functionId, - key, - message, - asStep, - serialization)); + ctx, workflowId, functionId, key, message, asStep, serialization)); } - public Object getEvent( - String targetId, String key, Duration timeout, GetWorkflowEventContext callerCtx) { + public Object getEvent(String targetId, String key, Duration timeout, GetEventCaller caller) { return dbRetry( () -> NotificationsDAO.getEvent( - dataSource, - schema, - serializer, - notificationService, - dbPollingInterval, - targetId, - key, - timeout, - callerCtx)); + ctx, targetId, key, timeout, caller, dbPollingInterval, notificationSource)); } public void sleep(String workflowId, int functionId, Duration duration) { - dbRetry(() -> StepsDAO.sleep(dataSource, workflowId, functionId, duration, schema, serializer)); + dbRetry(() -> StepsDAO.sleep(ctx, workflowId, functionId, duration)); } public void cancelWorkflows(List workflowIds) { - dbRetry(() -> WorkflowDAO.cancelWorkflows(dataSource, schema, workflowIds)); + dbRetry(() -> WorkflowDAO.cancelWorkflows(ctx, workflowIds)); } public void resumeWorkflows(List workflowIds, String queueName) { - dbRetry(() -> WorkflowDAO.resumeWorkflows(dataSource, schema, workflowIds, queueName)); + dbRetry(() -> WorkflowDAO.resumeWorkflows(ctx, workflowIds, queueName)); } public void deleteWorkflows(List workflowIds, boolean deleteChildren) { - dbRetry(() -> WorkflowDAO.deleteWorkflows(dataSource, schema, workflowIds, deleteChildren)); + dbRetry(() -> WorkflowDAO.deleteWorkflows(ctx, workflowIds, deleteChildren)); } public String forkWorkflow(String originalWorkflowId, int startStep, ForkOptions options) { - return dbRetry( - () -> - WorkflowDAO.forkWorkflow( - dataSource, schema, serializer, originalWorkflowId, startStep, options)); + return dbRetry(() -> WorkflowDAO.forkWorkflow(ctx, originalWorkflowId, startStep, options)); } public void createApplicationVersion(String versionName) { - dbRetry(() -> ApplicationVersionDAO.createApplicationVersion(dataSource, schema, versionName)); + dbRetry(() -> ApplicationVersionDAO.createApplicationVersion(ctx, versionName)); } public void updateApplicationVersionTimestamp(String versionName, Instant newTimestamp) { dbRetry( () -> ApplicationVersionDAO.updateApplicationVersionTimestamp( - dataSource, schema, versionName, newTimestamp)); + ctx, versionName, newTimestamp)); } public List listApplicationVersions() { - return dbRetry(() -> ApplicationVersionDAO.listApplicationVersions(dataSource, schema)); + return dbRetry(() -> ApplicationVersionDAO.listApplicationVersions(ctx)); } public VersionInfo getLatestApplicationVersion() { - return dbRetry(() -> ApplicationVersionDAO.getLatestApplicationVersion(dataSource, schema)); + return dbRetry(() -> ApplicationVersionDAO.getLatestApplicationVersion(ctx)); } public void garbageCollect(Instant cutoff, Long rowsThreshold) { - dbRetry(() -> WorkflowDAO.garbageCollect(dataSource, schema, cutoff, rowsThreshold)); + dbRetry(() -> WorkflowDAO.garbageCollect(ctx, cutoff, rowsThreshold)); } public void setWorkflowDelay(String workflowId, WorkflowDelay delay) { - dbRetry(() -> WorkflowDAO.setWorkflowDelay(dataSource, schema, workflowId, delay)); + dbRetry(() -> WorkflowDAO.setWorkflowDelay(ctx, workflowId, delay)); } public void transitionDelayedWorkflows() { - dbRetry(() -> WorkflowDAO.transitionDelayedWorkflows(dataSource, schema)); + dbRetry(() -> WorkflowDAO.transitionDelayedWorkflows(ctx)); } public void createSchedule(WorkflowSchedule schedule) { - dbRetry(() -> SchedulesDAO.createSchedule(dataSource, schema, serializer, schedule)); + dbRetry(() -> SchedulesDAO.createSchedule(ctx, schedule)); } public Optional getSchedule(String name) { - return dbRetry(() -> SchedulesDAO.getSchedule(dataSource, schema, serializer, name)); + return dbRetry(() -> SchedulesDAO.getSchedule(ctx, name)); } public List listSchedules( @@ -534,74 +556,67 @@ public List listSchedules( List workflowNames, List scheduleNamePrefixes) { return dbRetry( - () -> - SchedulesDAO.listSchedules( - dataSource, schema, serializer, statuses, workflowNames, scheduleNamePrefixes)); + () -> SchedulesDAO.listSchedules(ctx, statuses, workflowNames, scheduleNamePrefixes)); } public void pauseSchedule(String name) { - dbRetry(() -> SchedulesDAO.pauseSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.pauseSchedule(ctx, name)); } public void resumeSchedule(String name) { - dbRetry(() -> SchedulesDAO.resumeSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.resumeSchedule(ctx, name)); } public void updateScheduleLastFiredAt(String name, Instant lastFiredAt) { - dbRetry(() -> SchedulesDAO.updateScheduleLastFiredAt(dataSource, schema, name, lastFiredAt)); + dbRetry(() -> SchedulesDAO.updateScheduleLastFiredAt(ctx, name, lastFiredAt)); } public void deleteSchedule(String name) { - dbRetry(() -> SchedulesDAO.deleteSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.deleteSchedule(ctx, name)); } public void applySchedules(List schedules) { - dbRetry(() -> SchedulesDAO.applySchedules(dataSource, schema, serializer, schedules)); + dbRetry(() -> SchedulesDAO.applySchedules(ctx, schedules)); } public Optional getExternalState(String service, String workflowName, String key) { - return dbRetry( - () -> ExternalStateDAO.getExternalState(dataSource, schema, service, workflowName, key)); + return dbRetry(() -> ExternalStateDAO.getExternalState(ctx, service, workflowName, key)); } public ExternalState upsertExternalState(ExternalState state) { - return dbRetry(() -> ExternalStateDAO.upsertExternalState(dataSource, schema, state)); + return dbRetry(() -> ExternalStateDAO.upsertExternalState(ctx, state)); } public List getMetrics(Instant startTime, Instant endTime) { - return dbRetry(() -> WorkflowDAO.getMetrics(dataSource, schema, startTime, endTime)); + return dbRetry(() -> WorkflowDAO.getMetrics(ctx, startTime, endTime)); } public boolean patch(String workflowId, int functionId, String patchName) { - return dbRetry(() -> StepsDAO.patch(dataSource, workflowId, functionId, patchName, schema)); + return dbRetry(() -> StepsDAO.patch(ctx, workflowId, functionId, patchName)); } public boolean deprecatePatch(String workflowId, int functionId, String patchName) { - return dbRetry( - () -> StepsDAO.deprecatePatch(dataSource, workflowId, functionId, patchName, schema)); + return dbRetry(() -> StepsDAO.deprecatePatch(ctx, workflowId, functionId, patchName)); } public Set getWorkflowChildren(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowChildren(dataSource, schema, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowChildren(ctx, workflowId)); } public Map getAllEvents(String workflowId) { - return dbRetry(() -> WorkflowDAO.getAllEvents(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> WorkflowDAO.getAllEvents(ctx, workflowId)); } public List getAllNotifications(String workflowId) { - return dbRetry( - () -> NotificationsDAO.getAllNotifications(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> NotificationsDAO.getAllNotifications(ctx, workflowId)); } public List exportWorkflow(String workflowId, boolean exportChildren) { - return dbRetry( - () -> - WorkflowDAO.exportWorkflow(dataSource, schema, serializer, workflowId, exportChildren)); + return dbRetry(() -> WorkflowDAO.exportWorkflow(ctx, workflowId, exportChildren)); } public void importWorkflow(List workflows) { - dbRetry(() -> WorkflowDAO.importWorkflow(dataSource, schema, serializer, workflows)); + dbRetry(() -> WorkflowDAO.importWorkflow(ctx, workflows)); } public void writeStreamFromStep( @@ -609,7 +624,7 @@ public void writeStreamFromStep( dbRetry( () -> StreamsDAO.writeStreamFromStep( - dataSource, schema, workflowId, functionId, key, value, serializationFormat)); + ctx, workflowId, functionId, key, value, serializationFormat)); } public void writeStreamFromWorkflow( @@ -617,18 +632,18 @@ public void writeStreamFromWorkflow( dbRetry( () -> StreamsDAO.writeStreamFromWorkflow( - dataSource, schema, workflowId, functionId, key, value, serializationFormat)); + ctx, workflowId, functionId, key, value, serializationFormat)); } public void closeStream(String workflowId, int functionId, String key) { - dbRetry(() -> StreamsDAO.closeStream(dataSource, schema, workflowId, functionId, key)); + dbRetry(() -> StreamsDAO.closeStream(ctx, workflowId, functionId, key)); } public Object readStream(String workflowId, String key, int offset) { - return dbRetry(() -> StreamsDAO.readStream(dataSource, schema, workflowId, key, offset)); + return dbRetry(() -> StreamsDAO.readStream(ctx, workflowId, key, offset)); } public Map> getAllStreamEntries(String workflowId) { - return dbRetry(() -> StreamsDAO.getAllStreamEntries(dataSource, schema, workflowId)); + return dbRetry(() -> StreamsDAO.getAllStreamEntries(ctx, workflowId)); } } diff --git a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java similarity index 71% rename from transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java index 0ccd74023..fd7c6fbfb 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.workflow.VersionInfo; import java.sql.SQLException; @@ -8,13 +9,11 @@ import java.util.List; import java.util.UUID; -import javax.sql.DataSource; - -class ApplicationVersionDAO { +public class ApplicationVersionDAO { private ApplicationVersionDAO() {} - static void createApplicationVersion(DataSource dataSource, String schema, String versionName) + public static void createApplicationVersion(DbContext ctx, String versionName) throws SQLException { String sql = """ @@ -22,8 +21,8 @@ static void createApplicationVersion(DataSource dataSource, String schema, Strin VALUES (?, ?) ON CONFLICT (version_name) DO NOTHING """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, UUID.randomUUID().toString()); stmt.setString(2, versionName); @@ -31,17 +30,16 @@ ON CONFLICT (version_name) DO NOTHING } } - static void updateApplicationVersionTimestamp( - DataSource dataSource, String schema, String versionName, Instant newTimestamp) - throws SQLException { + public static void updateApplicationVersionTimestamp( + DbContext ctx, String versionName, Instant newTimestamp) throws SQLException { String sql = """ UPDATE "%s".application_versions SET version_timestamp = ? WHERE version_name = ? """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, newTimestamp.toEpochMilli()); stmt.setString(2, versionName); @@ -49,17 +47,16 @@ static void updateApplicationVersionTimestamp( } } - static List listApplicationVersions(DataSource dataSource, String schema) - throws SQLException { + public static List listApplicationVersions(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at FROM "%s".application_versions ORDER BY version_timestamp DESC """ - .formatted(schema); + .formatted(ctx.schema()); List results = new ArrayList<>(); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql); var rs = stmt.executeQuery()) { while (rs.next()) { @@ -74,8 +71,7 @@ static List listApplicationVersions(DataSource dataSource, String s return results; } - static VersionInfo getLatestApplicationVersion(DataSource dataSource, String schema) - throws SQLException { + public static VersionInfo getLatestApplicationVersion(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at @@ -83,8 +79,8 @@ static VersionInfo getLatestApplicationVersion(DataSource dataSource, String sch ORDER BY version_timestamp DESC LIMIT 1 """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql); var rs = stmt.executeQuery()) { if (rs.next()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java similarity index 85% rename from transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java index ef395d3df..476e5070c 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java @@ -1,4 +1,7 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; + +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.ExternalState; import java.math.BigDecimal; import java.math.BigInteger; @@ -6,22 +9,19 @@ import java.util.Objects; import java.util.Optional; -import javax.sql.DataSource; - -class ExternalStateDAO { +public class ExternalStateDAO { private ExternalStateDAO() {} - static Optional getExternalState( - DataSource dataSource, String schema, String service, String workflowName, String key) - throws SQLException { + public static Optional getExternalState( + DbContext ctx, String service, String workflowName, String key) throws SQLException { final String sql = """ SELECT value, update_seq, update_time FROM "%s".event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, Objects.requireNonNull(service, "service must not be null")); stmt.setString(2, Objects.requireNonNull(workflowName, "workflowName must not be null")); @@ -41,8 +41,8 @@ static Optional getExternalState( } } - static ExternalState upsertExternalState( - DataSource dataSource, String schema, ExternalState state) throws SQLException { + public static ExternalState upsertExternalState(DbContext ctx, ExternalState state) + throws SQLException { final var sql = """ INSERT INTO "%s".event_dispatch_kv ( @@ -58,9 +58,9 @@ ON CONFLICT (service_name, workflow_fn_name, key) ) THEN EXCLUDED.value ELSE event_dispatch_kv.value END RETURNING value, update_time, update_seq """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, Objects.requireNonNull(state.service(), "service must not be null")); stmt.setString( diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java similarity index 50% rename from transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java index 239511cc8..e8ef29ac5 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java @@ -1,8 +1,12 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; import dev.dbos.transact.Constants; +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.GetEventCaller; +import dev.dbos.transact.database.SystemDatabase.NotificationRegistry; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalMap; import dev.dbos.transact.exceptions.DBOSNonExistentWorkflowException; -import dev.dbos.transact.exceptions.DBOSWorkflowExecutionConflictException; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; import dev.dbos.transact.workflow.NotificationInfo; @@ -17,24 +21,22 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.UUID; -import java.util.concurrent.TimeUnit; - -import javax.sql.DataSource; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class NotificationsDAO { +public class NotificationsDAO { private NotificationsDAO() {} private static final Logger logger = LoggerFactory.getLogger(NotificationsDAO.class); - static void send( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + public static void send( + DbContext ctx, String workflowId, int stepId, String destinationId, @@ -44,16 +46,17 @@ static void send( String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.send"; String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; - try (Connection conn = dataSource.getConnection()) { + try (Connection conn = ctx.getConnection()) { conn.setAutoCommit(false); try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, stepId, functionName, conn, schema); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, stepId, functionName); if (recordedOutput != null) { logger.debug( @@ -81,7 +84,7 @@ static void send( VALUES (?, ?, ?, ?, ?) ON CONFLICT (message_uuid) DO NOTHING """ - .formatted(schema); + .formatted(ctx.schema()); try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, destinationId); @@ -98,7 +101,8 @@ ON CONFLICT (message_uuid) DO NOTHING } var output = new StepResult(workflowId, stepId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResult( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); @@ -113,16 +117,15 @@ ON CONFLICT (message_uuid) DO NOTHING } } - static void sendDirect( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + public static void sendDirect( + DbContext ctx, String destinationId, Object message, String topic, String messageId, String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; String finalMessageId = (messageId != null) ? messageId : UUID.randomUUID().toString(); var serializedMsg = SerializationUtil.serializeValue(message, serialization, serializer); @@ -134,9 +137,9 @@ static void sendDirect( VALUES (?, ?, ?, ?, ?) ON CONFLICT (message_uuid) DO NOTHING """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, destinationId); stmt.setString(2, finalTopic); @@ -152,131 +155,111 @@ ON CONFLICT (message_uuid) DO NOTHING } } - static Object recv( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - NotificationService notificationService, - Duration dbPollingInterval, + public static Object recv( + DbContext ctx, String workflowId, int stepId, - int timeoutFunctionId, + Duration timeout, + int timeoutStepId, String topic, - Duration timeout) + Duration dbPollingInterval, + NotificationRegistry notifcationRegistry) throws SQLException { - var startTime = System.currentTimeMillis(); - String functionName = "DBOS.recv"; - String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; - - StepResult recordedOutput; - try (Connection c = dataSource.getConnection()) { - recordedOutput = StepsDAO.checkStepExecutionTxn(workflowId, stepId, functionName, c, schema); + if (Objects.requireNonNull(workflowId).isEmpty()) { + throw new IllegalArgumentException("workflowId must not be empty"); } - if (recordedOutput != null) { - logger.debug("Replaying recv, id: {}, topic: {}", stepId, finalTopic); - if (recordedOutput.output() != null) { - return SerializationUtil.deserializeValue( - recordedOutput.output(), recordedOutput.serialization(), serializer); - } else { - throw new RuntimeException("No output recorded in the last recv"); + var stepName = "DBOS.recv"; + topic = Objects.requireNonNullElse(topic, Constants.DBOS_NULL_TOPIC); + + var recordedResult = StepsDAO.checkStepResult(ctx, workflowId, stepId, stepName); + if (recordedResult != null) { + logger.debug( + "Replaying recv, workflowId: {}, stepId: {}, topic: {}", workflowId, stepId, topic); + if (recordedResult.output() != null) { + return recordedResult.toResult(ctx.serializer()); } - } else { - logger.debug("Running recv, wfid {}, id: {}, topic: {}", workflowId, stepId, finalTopic); + logger.debug( + "Running recv, workflowId: {}, stepId: {}, topic: {}", workflowId, stepId, topic); } - String payload = workflowId + "::" + finalTopic; - var lockPair = new NotificationService.LockConditionPair(); - - double actualTimeout = timeout.toMillis(); - var targetTime = System.currentTimeMillis() + actualTimeout; - var checkedDBForSleep = false; - - try { - lockPair.lock.lock(); - boolean success = notificationService.registerNotificationCondition(payload, lockPair); - if (!success) { - throw new DBOSWorkflowExecutionConflictException(workflowId); - } + var startTime = System.currentTimeMillis(); + var messageKey = new SignalKey.Message(workflowId, topic); + dbPollingInterval = Objects.requireNonNullElse(dbPollingInterval, Duration.ofSeconds(1)); + try (var messageSignal = notifcationRegistry.subscribe(messageKey)) { while (true) { - boolean hasExistingNotification; - try (Connection conn = dataSource.getConnection()) { - final String sql = - """ + ctx.checkClosed(); + var sql = + """ SELECT topic FROM "%s".notifications WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE - """ - .formatted(schema); - - try (PreparedStatement stmt = conn.prepareStatement(sql)) { - stmt.setString(1, workflowId); - stmt.setString(2, finalTopic); - try (ResultSet rs = stmt.executeQuery()) { - hasExistingNotification = rs.next(); + """ + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setString(2, topic); + try (var rs = stmt.executeQuery()) { + if (rs.next()) { + // query for results + break; } } } - if (hasExistingNotification) break; - - var nowTime = System.currentTimeMillis(); + // check cancelled - if (!checkedDBForSleep) { - actualTimeout = - StepsDAO.durableSleepDuration( - dataSource, workflowId, timeoutFunctionId, timeout, schema, serializer) - .toMillis(); - checkedDBForSleep = true; - targetTime = nowTime + actualTimeout; + var sleepDuration = StepsDAO.durableSleepDuration(ctx, workflowId, timeoutStepId, timeout); + if (sleepDuration.isNegative() || sleepDuration.isZero()) { + var output = SerializationUtil.serializeValue(null, null, ctx.serializer()); + var stepResult = StepResult.ofOutput(workflowId, stepId, stepName, output); + StepsDAO.recordStepResult(ctx, stepResult, startTime); + return null; } - if (nowTime >= targetTime) break; - long timeoutMs = (long) Math.min(targetTime - nowTime, dbPollingInterval.toMillis()); - try { - lockPair.condition.await(timeoutMs, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while waiting for message", e); - } + var loopDuration = + dbPollingInterval.compareTo(sleepDuration) <= 0 ? dbPollingInterval : sleepDuration; + + SignalMap.awaitAny(loopDuration, messageSignal); } - } finally { - lockPair.lock.unlock(); - notificationService.unregisterNotificationCondition(payload); } - try (Connection conn = dataSource.getConnection()) { - conn.setAutoCommit(false); + ctx.checkClosed(); + var sql = + """ + UPDATE "%1$s".notifications + SET consumed = TRUE + WHERE destination_uuid = ? + AND topic = ? + AND consumed = FALSE + AND message_uuid = ( + SELECT message_uuid FROM "%1$s".notifications + WHERE destination_uuid = ? + AND topic = ? + AND consumed = FALSE + ORDER BY created_at_epoch_ms ASC + LIMIT 1 + ) + RETURNING message, serialization + """ + .formatted(ctx.schema()); + try (var conn = ctx.getConnection()) { + conn.setAutoCommit(false); try { - final String sql = - """ - UPDATE "%1$s".notifications - SET consumed = TRUE - WHERE destination_uuid = ? - AND topic = ? - AND consumed = FALSE - AND message_uuid = ( - SELECT message_uuid FROM "%1$s".notifications - WHERE destination_uuid = ? - AND topic = ? - AND consumed = FALSE - ORDER BY created_at_epoch_ms ASC - LIMIT 1 - ) - RETURNING message, serialization - """ - .formatted(schema); - String serializedMessage = null; String serialization = null; try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); - stmt.setString(2, finalTopic); + stmt.setString(2, topic); stmt.setString(3, workflowId); - stmt.setString(4, finalTopic); + stmt.setString(4, topic); + // Note, if there are two executors running the same workflow waiting on the same recv, + // only the first one will return a row here. The second one get a null message but then + // throw a WorkflowExecutionConflictException when it records the step result. try (ResultSet rs = stmt.executeQuery()) { if (rs.next()) { serializedMessage = rs.getString("message"); @@ -285,17 +268,16 @@ static Object recv( } } - var recvdMessage = - SerializationUtil.deserializeValue(serializedMessage, serialization, serializer); + var deserializedMessage = + SerializationUtil.deserializeValue(serializedMessage, serialization, ctx.serializer()); - StepResult output = + var output = new StepResult( - workflowId, stepId, functionName, serializedMessage, null, null, serialization); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + workflowId, stepId, stepName, serializedMessage, null, null, serialization); + StepsDAO.recordStepResult(conn, ctx.schema(), output, startTime); conn.commit(); - return recvdMessage; - + return deserializedMessage; } catch (Exception e) { conn.rollback(); throw e; @@ -321,7 +303,7 @@ ON CONFLICT (workflow_uuid, key) """ .formatted(schema); - try (PreparedStatement stmt = conn.prepareStatement(eventSql)) { + try (var stmt = conn.prepareStatement(eventSql)) { stmt.setString(1, workflowId); stmt.setString(2, key); stmt.setString(3, message); @@ -338,7 +320,7 @@ ON CONFLICT (workflow_uuid, key, function_id) """ .formatted(schema); - try (PreparedStatement stmt = conn.prepareStatement(eventHistorySql)) { + try (var stmt = conn.prepareStatement(eventHistorySql)) { stmt.setString(1, workflowId); stmt.setInt(2, functionId); stmt.setString(3, key); @@ -348,10 +330,8 @@ ON CONFLICT (workflow_uuid, key, function_id) } } - static void setEvent( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + public static void setEvent( + DbContext ctx, String workflowId, int functionId, String key, @@ -360,18 +340,19 @@ static void setEvent( String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.setEvent"; SerializationUtil.SerializedResult serializedResult = SerializationUtil.serializeValue(message, serialization, serializer); - try (Connection conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { if (asStep) { var recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, functionId, functionName, conn, schema); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug( "Replaying setEvent, workflow: {}, step: {}, key: {}", workflowId, functionId, key); @@ -385,7 +366,7 @@ static void setEvent( setEvent( conn, - schema, + ctx.schema(), workflowId, functionId, key, @@ -395,7 +376,7 @@ static void setEvent( if (asStep) { StepResult output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResult(conn, ctx.schema(), output, startTime); } conn.commit(); @@ -408,136 +389,114 @@ static void setEvent( } } - static Object getEvent( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - NotificationService notificationService, - Duration dbPollingInterval, - String targetUuid, + private record GetEventResult(String value, String serialization) {} + + private static Optional getEvent( + DbContext ctx, @NonNull String workflowId, @NonNull String key) throws SQLException { + var sql = + """ + SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ? + """ + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setString(2, key); + try (var rs = stmt.executeQuery()) { + if (rs.next()) { + var value = rs.getString("value"); + var serialization = rs.getString("serialization"); + return Optional.of(new GetEventResult(value, serialization)); + } + } + } + + return Optional.empty(); + } + + public static Object getEvent( + DbContext ctx, + String workflowId, String key, Duration timeout, - GetWorkflowEventContext callerCtx) + @Nullable GetEventCaller caller, + Duration dbPollingInterval, + NotificationRegistry notifcationRegistry) throws SQLException { - var startTime = System.currentTimeMillis(); - String functionName = "DBOS.getEvent"; - - if (callerCtx != null) { - StepResult recordedOutput; - try (Connection conn = dataSource.getConnection()) { - recordedOutput = - StepsDAO.checkStepExecutionTxn( - callerCtx.workflowId(), callerCtx.functionId(), functionName, conn, schema); - } + if (Objects.requireNonNull(workflowId).isEmpty()) { + throw new IllegalArgumentException("workflowId must not be empty"); + } - if (recordedOutput != null) { - logger.debug("Replaying getEvent, id: {}, key: {}", callerCtx.functionId(), key); - if (recordedOutput.output() != null) { - return SerializationUtil.deserializeValue( - recordedOutput.output(), recordedOutput.serialization(), serializer); - } else { - throw new RuntimeException("No output recorded in the last getEvent"); - } - } else { - logger.debug("Running getEvent, id: {}, key: {}", callerCtx.functionId(), key); + var stepName = "DBOS.getEvent"; + + if (caller != null) { + var prevResult = + StepsDAO.checkStepResult(ctx, caller.workflowId(), caller.stepId(), stepName); + if (prevResult != null) { + logger.debug("Replaying getEvent, id: {}, key: {}", caller.stepId(), key); + return prevResult.toResult(ctx.serializer()); } + logger.debug("Running getEvent, id: {}, key: {}", caller.stepId(), key); } - String payload = targetUuid + "::" + key; - NotificationService.LockConditionPair lockConditionPair = - notificationService.getOrCreateNotificationCondition(payload); - - lockConditionPair.lock.lock(); - try { - Object value = null; - final String sql = - """ - SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ? - """ - .formatted(schema); - - double actualTimeout = - Objects.requireNonNull(timeout, "getEvent timeout cannot be null").toMillis(); - var targetTime = System.currentTimeMillis() + actualTimeout; - var checkedDBForSleep = false; - var hasExistingNotification = false; + var startTime = Instant.now(); + var eventKey = new SignalKey.Event(workflowId, key); + dbPollingInterval = Objects.requireNonNullElse(dbPollingInterval, Duration.ofSeconds(1)); + GetEventResult result = null; + try (var eventSignal = notifcationRegistry.subscribe(eventKey)) { while (true) { - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(sql)) { + var optResult = getEvent(ctx, workflowId, key); + if (optResult.isPresent()) { + result = optResult.get(); + break; + } - stmt.setString(1, targetUuid); - stmt.setString(2, key); + // check cancelled (both workflowId and caller.workflowId) - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - String serializedValue = rs.getString("value"); - String serialization = rs.getString("serialization"); - value = - SerializationUtil.deserializeValue(serializedValue, serialization, serializer); - hasExistingNotification = true; - } - } - } + var sleepDuration = + caller != null + ? StepsDAO.durableSleepDuration( + ctx, caller.workflowId(), caller.timeoutStepId(), timeout) + : timeout.minus(Duration.between(startTime, Instant.now())); - if (hasExistingNotification) break; - var nowTime = System.currentTimeMillis(); - if (nowTime > targetTime) break; - - if (callerCtx != null && !checkedDBForSleep) { - actualTimeout = - StepsDAO.durableSleepDuration( - dataSource, - callerCtx.workflowId(), - callerCtx.timeoutFunctionId(), - timeout, - schema, - serializer) - .toMillis(); - targetTime = System.currentTimeMillis() + actualTimeout; - checkedDBForSleep = true; - if (nowTime > targetTime) break; + if (sleepDuration.isNegative() || sleepDuration.isZero()) { + var serialized = SerializationUtil.serializeValue(null, null, ctx.serializer()); + result = new GetEventResult(serialized.serializedValue(), serialized.serialization()); + break; } - try { - long timeoutms = (long) (targetTime - nowTime); - logger.debug("Waiting for notification {}...", timeout); - lockConditionPair.condition.await( - Math.min(timeoutms, dbPollingInterval.toMillis()), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while waiting for event", e); - } - } + var loopDuration = + dbPollingInterval.compareTo(sleepDuration) <= 0 ? dbPollingInterval : sleepDuration; - if (callerCtx != null) { - var toSaveSer = SerializationUtil.serializeValue(value, null, serializer); - StepResult output = - new StepResult( - callerCtx.workflowId(), - callerCtx.functionId(), - functionName, - null, - null, - null, - toSaveSer.serialization()) - .withOutput(toSaveSer.serializedValue()); - StepsDAO.recordStepResultTxn( - dataSource, output, startTime, System.currentTimeMillis(), schema); + SignalMap.awaitAny(loopDuration, eventSignal); } + } - return value; - - } finally { - lockConditionPair.lock.unlock(); - notificationService.unregisterNotificationCondition(payload); + Objects.requireNonNull(result); + ctx.checkClosed(); + + if (caller != null) { + var stepResult = + new StepResult( + caller.workflowId(), + caller.stepId(), + stepName, + result.value(), + null, + null, + result.serialization()); + StepsDAO.recordStepResult(ctx, stepResult, startTime.toEpochMilli()); } + + return SerializationUtil.deserializeValue( + result.value(), result.serialization(), ctx.serializer()); } - static List getAllNotifications( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) + public static List getAllNotifications(DbContext ctx, String workflowId) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var sql = """ SELECT topic, message, serialization, created_at_epoch_ms, consumed @@ -545,10 +504,10 @@ static List getAllNotifications( WHERE destination_uuid = ? ORDER BY created_at_epoch_ms """ - .formatted(schema); + .formatted(ctx.schema()); var notifications = new ArrayList(); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java similarity index 89% rename from transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java index c0adfc482..f1544adff 100644 --- a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.workflow.Queue; import dev.dbos.transact.workflow.WorkflowState; @@ -14,31 +15,24 @@ import java.util.List; import java.util.Map; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class QueuesDAO { +public class QueuesDAO { private QueuesDAO() {} private static final Logger logger = LoggerFactory.getLogger(QueuesDAO.class); - static List getAndStartQueuedWorkflows( - DataSource dataSource, - String schema, - Queue queue, - String executorId, - String appVersion, - String partitionKey) + public static List getAndStartQueuedWorkflows( + DbContext ctx, Queue queue, String executorId, String appVersion, String partitionKey) throws SQLException { if (partitionKey != null && partitionKey.length() == 0) { partitionKey = null; } - try (Connection connection = dataSource.getConnection()) { + try (Connection connection = ctx.getConnection()) { connection.setAutoCommit(false); try (Statement stmt = connection.createStatement()) { @@ -59,7 +53,7 @@ SELECT COUNT(*) AND status NOT IN (?, ?) AND started_at_epoch_ms > ? """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { limiterQuery += " AND queue_partition_key = ?"; } @@ -94,7 +88,7 @@ SELECT executor_id, COUNT(*) as task_count FROM "%s".workflow_status WHERE queue_name = ? AND status = ? """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { pendingQuery += " AND queue_partition_key = ?"; } @@ -163,7 +157,7 @@ SELECT executor_id, COUNT(*) as task_count AND status = ? AND (application_version = ? OR application_version IS NULL) """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { query += " AND queue_partition_key = ?"; } @@ -216,12 +210,12 @@ SELECT executor_id, COUNT(*) as task_count started_at_epoch_ms = ?, workflow_deadline_epoch_ms = CASE WHEN workflow_timeout_ms IS NOT NULL AND workflow_deadline_epoch_ms IS NULL - THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms + THEN (EXTRACT(epoch FROM now()) * 1000)::bigint + workflow_timeout_ms ELSE workflow_deadline_epoch_ms END WHERE workflow_uuid = ? """ - .formatted(schema); + .formatted(ctx.schema()); try (var ps = connection.prepareStatement(updateQuery)) { for (var id : dequeuedWorkflowIds) { @@ -251,8 +245,7 @@ THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms } } - static boolean clearQueueAssignment(DataSource dataSource, String schema, String workflowId) - throws SQLException { + public static boolean clearQueueAssignment(DbContext ctx, String workflowId) throws SQLException { final String sql = """ @@ -260,8 +253,8 @@ static boolean clearQueueAssignment(DataSource dataSource, String schema, String SET started_at_epoch_ms = NULL, status = ? WHERE workflow_uuid = ? AND queue_name IS NOT NULL AND status = ? """ - .formatted(schema); - try (Connection connection = dataSource.getConnection(); + .formatted(ctx.schema()); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, WorkflowState.ENQUEUED.name()); stmt.setString(2, workflowId); @@ -272,7 +265,7 @@ static boolean clearQueueAssignment(DataSource dataSource, String schema, String } } - static List getQueuePartitions(DataSource dataSource, String schema, String queueName) + public static List getQueuePartitions(DbContext ctx, String queueName) throws SQLException { final String sql = @@ -283,9 +276,9 @@ static List getQueuePartitions(DataSource dataSource, String schema, Str AND status = ? AND queue_partition_key IS NOT NULL """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, queueName); stmt.setString(2, WorkflowState.ENQUEUED.name()); diff --git a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java similarity index 77% rename from transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java index 8e2f07159..e846940dd 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.execution.SchedulerService; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; @@ -20,17 +21,13 @@ import java.util.StringJoiner; import java.util.UUID; -import javax.sql.DataSource; - -class SchedulesDAO { +public class SchedulesDAO { private SchedulesDAO() {} - static void createSchedule( - DataSource dataSource, String schema, DBOSSerializer serializer, WorkflowSchedule schedule) - throws SQLException { - try (Connection conn = dataSource.getConnection()) { - createSchedule(conn, schema, serializer, schedule); + public static void createSchedule(DbContext ctx, WorkflowSchedule schedule) throws SQLException { + try (Connection conn = ctx.getConnection()) { + createSchedule(conn, ctx.schema(), ctx.serializer(), schedule); } } @@ -84,15 +81,14 @@ static void createSchedule( } } - static List listSchedules( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + public static List listSchedules( + DbContext ctx, List statuses, List workflowNames, List scheduleNamePrefixes) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); StringBuilder sql = new StringBuilder( """ @@ -102,7 +98,7 @@ static List listSchedules( FROM "%s".workflow_schedules WHERE TRUE """ - .formatted(schema)); + .formatted(ctx.schema())); List params = new ArrayList<>(); @@ -124,7 +120,7 @@ static List listSchedules( sql.append(orClauses).append(")"); } - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql.toString())) { List arrays = new ArrayList<>(); try { @@ -153,9 +149,9 @@ static List listSchedules( } } - static Optional getSchedule( - DataSource dataSource, String schema, DBOSSerializer serializer, String name) + public static Optional getSchedule(DbContext ctx, String name) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); String sql = """ SELECT schedule_id, schedule_name, workflow_name, workflow_class_name, @@ -164,9 +160,9 @@ static Optional getSchedule( FROM "%s".workflow_schedules WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, name); try (ResultSet rs = ps.executeQuery()) { @@ -178,25 +174,23 @@ static Optional getSchedule( } } - static void pauseSchedule(DataSource dataSource, String schema, String name) throws SQLException { - setScheduleStatus(dataSource, schema, name, ScheduleStatus.PAUSED); + public static void pauseSchedule(DbContext ctx, String name) throws SQLException { + setScheduleStatus(ctx, name, ScheduleStatus.PAUSED); } - static void resumeSchedule(DataSource dataSource, String schema, String name) - throws SQLException { - setScheduleStatus(dataSource, schema, name, ScheduleStatus.ACTIVE); + public static void resumeSchedule(DbContext ctx, String name) throws SQLException { + setScheduleStatus(ctx, name, ScheduleStatus.ACTIVE); } - private static void setScheduleStatus( - DataSource dataSource, String schema, String name, ScheduleStatus status) + private static void setScheduleStatus(DbContext ctx, String name, ScheduleStatus status) throws SQLException { String sql = """ UPDATE "%s".workflow_schedules SET status = ? WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, status.name()); ps.setString(2, name); @@ -204,15 +198,15 @@ private static void setScheduleStatus( } } - static void updateScheduleLastFiredAt( - DataSource dataSource, String schema, String name, Instant lastFiredAt) throws SQLException { + public static void updateScheduleLastFiredAt(DbContext ctx, String name, Instant lastFiredAt) + throws SQLException { String sql = """ UPDATE "%s".workflow_schedules SET last_fired_at = ? WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, lastFiredAt != null ? lastFiredAt.toString() : null); ps.setString(2, name); @@ -220,10 +214,9 @@ static void updateScheduleLastFiredAt( } } - static void deleteSchedule(DataSource dataSource, String schema, String name) - throws SQLException { - try (Connection conn = dataSource.getConnection()) { - deleteSchedule(conn, schema, name); + public static void deleteSchedule(DbContext ctx, String name) throws SQLException { + try (var conn = ctx.getConnection()) { + deleteSchedule(conn, ctx.schema(), name); } } @@ -234,27 +227,23 @@ static void deleteSchedule(Connection conn, String schema, String name) throws S """ .formatted(schema); - try (PreparedStatement ps = conn.prepareStatement(sql)) { - ps.setString(1, name); - ps.executeUpdate(); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, name); + stmt.executeUpdate(); } } - static void applySchedules( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List schedules) + public static void applySchedules(DbContext ctx, List schedules) throws SQLException { - try (Connection conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { for (WorkflowSchedule schedule : schedules) { - deleteSchedule(conn, schema, schedule.scheduleName()); + deleteSchedule(conn, ctx.schema(), schedule.scheduleName()); createSchedule( conn, - schema, - serializer, + ctx.schema(), + ctx.serializer(), schedule .withScheduleId(UUID.randomUUID().toString()) .withStatus(ScheduleStatus.ACTIVE) @@ -264,8 +253,6 @@ static void applySchedules( } catch (SQLException e) { conn.rollback(); throw e; - } finally { - conn.setAutoCommit(true); } } } diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java similarity index 69% rename from transact/src/main/java/dev/dbos/transact/database/StepsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java index f6cc20523..45015c3c9 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.exceptions.*; import dev.dbos.transact.internal.DebugTriggers; import dev.dbos.transact.json.DBOSSerializer; @@ -16,36 +17,37 @@ import java.util.List; import java.util.Objects; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class StepsDAO { +public class StepsDAO { private StepsDAO() {} private static final Logger logger = LoggerFactory.getLogger(StepsDAO.class); - static void recordStepResultTxn( - DataSource dataSource, - StepResult result, - long startTimeEpochMs, - long endTimeEpochMs, - String schema) + static void recordStepResult(DbContext ctx, StepResult result, long startTimeEpochMs) + throws SQLException { + recordStepResult(ctx, result, startTimeEpochMs, System.currentTimeMillis()); + } + + public static void recordStepResult( + DbContext ctx, StepResult result, long startTimeEpochMs, long endTimeEpochMs) throws SQLException { - try (Connection connection = dataSource.getConnection()) { - recordStepResultTxn(result, startTimeEpochMs, endTimeEpochMs, connection, schema); + try (var conn = ctx.getConnection()) { + recordStepResult(conn, ctx.schema(), result, startTimeEpochMs, endTimeEpochMs); } DebugTriggers.debugTriggerPoint(DebugTriggers.DEBUG_TRIGGER_STEP_COMMIT); } - static void recordStepResultTxn( - StepResult result, - Long startTimeEpochMs, - Long endTimeEpochMs, - Connection connection, - String schema) + static void recordStepResult( + Connection conn, String schema, StepResult result, long startTimeEpochMs) + throws SQLException { + recordStepResult(conn, schema, result, startTimeEpochMs, System.currentTimeMillis()); + } + + static void recordStepResult( + Connection conn, String schema, StepResult result, Long startTimeEpochMs, Long endTimeEpochMs) throws SQLException { Objects.requireNonNull(schema); @@ -58,33 +60,33 @@ static void recordStepResultTxn( """ .formatted(schema); - try (PreparedStatement pstmt = connection.prepareStatement(sql)) { - pstmt.setString(1, result.workflowId()); - pstmt.setInt(2, result.stepId()); - pstmt.setString(3, result.stepName()); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, result.workflowId()); + stmt.setInt(2, result.stepId()); + stmt.setString(3, result.stepName()); if (result.output() != null) { - pstmt.setString(4, result.output()); + stmt.setString(4, result.output()); } else { - pstmt.setNull(4, Types.LONGVARCHAR); + stmt.setNull(4, Types.LONGVARCHAR); } if (result.error() != null) { - pstmt.setString(5, result.error()); + stmt.setString(5, result.error()); } else { - pstmt.setNull(5, Types.LONGVARCHAR); + stmt.setNull(5, Types.LONGVARCHAR); } if (result.childWorkflowId() != null) { - pstmt.setString(6, result.childWorkflowId()); + stmt.setString(6, result.childWorkflowId()); } else { - pstmt.setNull(6, Types.VARCHAR); + stmt.setNull(6, Types.VARCHAR); } - pstmt.setObject(7, startTimeEpochMs); - pstmt.setObject(8, endTimeEpochMs); + stmt.setObject(7, startTimeEpochMs); + stmt.setObject(8, endTimeEpochMs); - try (ResultSet rs = pstmt.executeQuery()) { + try (ResultSet rs = stmt.executeQuery()) { if (rs.next() && endTimeEpochMs != null) { long completedAt = rs.getLong("completed_at_epoch_ms"); if (completedAt != endTimeEpochMs) { @@ -106,9 +108,16 @@ static void recordStepResultTxn( } } - static StepResult checkStepExecutionTxn( - String workflowId, int functionId, String functionName, Connection connection, String schema) - throws SQLException, DBOSWorkflowCancelledException, DBOSUnexpectedStepException { + static StepResult checkStepResult( + DbContext ctx, String workflowId, int functionId, String functionName) throws SQLException { + try (var conn = ctx.getConnection()) { + return checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); + } + } + + public static StepResult checkStepResult( + Connection conn, String schema, String workflowId, int functionId, String functionName) + throws SQLException { Objects.requireNonNull(schema); final String sql = @@ -118,7 +127,7 @@ static StepResult checkStepExecutionTxn( .formatted(schema); String workflowStatus = null; - try (PreparedStatement pstmt = connection.prepareStatement(sql)) { + try (var pstmt = conn.prepareStatement(sql)) { pstmt.setString(1, workflowId); try (ResultSet rs = pstmt.executeQuery()) { if (rs.next()) { @@ -147,7 +156,7 @@ static StepResult checkStepExecutionTxn( StepResult recordedResult = null; String recordedFunctionName = null; - try (PreparedStatement pstmt = connection.prepareStatement(operationOutputSql)) { + try (var pstmt = conn.prepareStatement(operationOutputSql)) { pstmt.setString(1, workflowId); pstmt.setInt(2, functionId); try (ResultSet rs = pstmt.executeQuery()) { @@ -175,39 +184,33 @@ static StepResult checkStepExecutionTxn( return recordedResult; } - static List listWorkflowSteps( - DataSource dataSource, - String workflowId, - Boolean loadOutput, - Integer limit, - Integer offset, - String schema, - DBOSSerializer serializer) + public static List listWorkflowSteps( + DbContext ctx, String workflowId, Boolean loadOutput, Integer limit, Integer offset) throws SQLException { - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { return listWorkflowSteps( - connection, workflowId, loadOutput, limit, offset, schema, serializer); + conn, ctx.schema(), ctx.serializer(), workflowId, loadOutput, limit, offset); } } static List listWorkflowSteps( - Connection connection, + Connection conn, + String schema, + DBOSSerializer serializer, String workflowId, Boolean loadOutput, Integer limit, - Integer offset, - String schema, - DBOSSerializer serializer) + Integer offset) throws SQLException { StringBuilder sqlBuilder = new StringBuilder( """ - SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms, serialization - FROM "%s".operation_outputs - WHERE workflow_uuid = ? - ORDER BY function_id - """ + SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms, serialization + FROM "%s".operation_outputs + WHERE workflow_uuid = ? + ORDER BY function_id + """ .formatted(schema)); if (limit != null) { @@ -221,7 +224,7 @@ static List listWorkflowSteps( List steps = new ArrayList<>(); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { int paramIndex = 1; stmt.setString(paramIndex++, workflowId); @@ -277,16 +280,9 @@ static List listWorkflowSteps( return steps; } - static void sleep( - DataSource dataSource, - String workflowUuid, - int functionId, - Duration duration, - String schema, - DBOSSerializer serializer) + public static void sleep(DbContext ctx, String workflowUuid, int functionId, Duration duration) throws SQLException { - var sleepDuration = - durableSleepDuration(dataSource, workflowUuid, functionId, duration, schema, serializer); + var sleepDuration = durableSleepDuration(ctx, workflowUuid, functionId, duration); logger.debug("Sleeping for duration {}", sleepDuration); try { Thread.sleep(sleepDuration.toMillis()); @@ -296,7 +292,7 @@ static void sleep( } } - static String getCheckpointName(Connection conn, String workflowId, int functionId, String schema) + static String getCheckpointName(Connection conn, String schema, String workflowId, int functionId) throws SQLException { var sql = """ @@ -306,10 +302,10 @@ static String getCheckpointName(Connection conn, String workflowId, int function """ .formatted(schema); - try (var ps = conn.prepareStatement(sql)) { - ps.setString(1, workflowId); - ps.setInt(2, functionId); - try (var rs = ps.executeQuery()) { + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setInt(2, functionId); + try (var rs = stmt.executeQuery()) { if (rs.next()) { return rs.getString("function_name"); } else { @@ -319,15 +315,14 @@ static String getCheckpointName(Connection conn, String workflowId, int function } } - static boolean patch( - DataSource dataSource, String workflowId, int functionId, String patchName, String schema) + public static boolean patch(DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); - try (Connection conn = dataSource.getConnection()) { - var checkpointName = getCheckpointName(conn, workflowId, functionId, schema); + try (var conn = ctx.getConnection()) { + var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); if (checkpointName == null) { var output = new StepResult(workflowId, functionId, patchName, null, null, null, null); - recordStepResultTxn(output, System.currentTimeMillis(), null, conn, schema); + recordStepResult(conn, ctx.schema(), output, System.currentTimeMillis(), null); return true; } else { return patchName.equals(checkpointName); @@ -335,34 +330,27 @@ static boolean patch( } } - static boolean deprecatePatch( - DataSource dataSource, String workflowId, int functionId, String patchName, String schema) - throws SQLException { + public static boolean deprecatePatch( + DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); - try (Connection conn = dataSource.getConnection()) { - var checkpointName = getCheckpointName(conn, workflowId, functionId, schema); + try (var conn = ctx.getConnection()) { + var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); return patchName.equals(checkpointName); } } static Duration durableSleepDuration( - DataSource dataSource, - String workflowUuid, - int functionId, - Duration duration, - String schema, - DBOSSerializer serializer) - throws SQLException { + DbContext ctx, String workflowUuid, int functionId, Duration duration) throws SQLException { - Objects.requireNonNull(schema); + DBOSSerializer serializer = ctx.serializer(); + Objects.requireNonNull(ctx.schema()); var startTime = System.currentTimeMillis(); String functionName = "DBOS.sleep"; StepResult recordedOutput; - try (Connection connection = dataSource.getConnection()) { - recordedOutput = - checkStepExecutionTxn(workflowUuid, functionId, functionName, connection, schema); + try (var conn = ctx.getConnection()) { + recordedOutput = checkStepResult(conn, ctx.schema(), workflowUuid, functionId, functionName); } long endTime; @@ -398,7 +386,7 @@ static Duration durableSleepDuration( null, null, serializedValue.serialization()); - recordStepResultTxn(dataSource, output, startTime, (long) endTime, schema); + recordStepResult(ctx, output, startTime, (long) endTime); } catch (DBOSWorkflowExecutionConflictException e) { logger.error("Error recording sleep", e); } diff --git a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java similarity index 78% rename from transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java index 59f474035..02a7a05f1 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.json.SerializationUtil; import dev.dbos.transact.workflow.internal.StepResult; @@ -10,29 +11,25 @@ import java.util.List; import java.util.Map; -import javax.sql.DataSource; - -class StreamsDAO { +public class StreamsDAO { private StreamsDAO() {} - static void writeStreamFromStep( - DataSource dataSource, - String schema, + public static void writeStreamFromStep( + DbContext ctx, String workflowId, int functionId, String key, Object value, String serializationFormat) throws SQLException { - try (Connection conn = dataSource.getConnection()) { - insertStream(conn, schema, workflowId, functionId, key, value, serializationFormat); + try (var conn = ctx.getConnection()) { + insertStream(conn, ctx.schema(), workflowId, functionId, key, value, serializationFormat); } } - static void writeStreamFromWorkflow( - DataSource dataSource, - String schema, + public static void writeStreamFromWorkflow( + DbContext ctx, String workflowId, int functionId, String key, @@ -43,12 +40,12 @@ static void writeStreamFromWorkflow( STREAM_CLOSED_SENTINEL.equals(value) ? "DBOS.closeStream" : "DBOS.writeStream"; long startTime = System.currentTimeMillis(); - try (Connection conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, functionId, functionName, conn, schema); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug("Replaying writeStream, id: {}, key: {}", functionId, key); @@ -58,10 +55,11 @@ static void writeStreamFromWorkflow( logger.debug("Running writeStream, id: {}, key: {}", functionId, key); } - insertStream(conn, schema, workflowId, functionId, key, value, serializationFormat); + insertStream(conn, ctx.schema(), workflowId, functionId, key, value, serializationFormat); var output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResult( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); @@ -128,15 +126,13 @@ SELECT COALESCE(MAX("offset"), -1) + 1 } } - static void closeStream( - DataSource dataSource, String schema, String workflowId, int functionId, String key) + public static void closeStream(DbContext ctx, String workflowId, int functionId, String key) throws SQLException { writeStreamFromWorkflow( - dataSource, schema, workflowId, functionId, key, STREAM_CLOSED_SENTINEL, "portable_json"); + ctx, workflowId, functionId, key, STREAM_CLOSED_SENTINEL, "portable_json"); } - static Object readStream( - DataSource dataSource, String schema, String workflowId, String key, int offset) + public static Object readStream(DbContext ctx, String workflowId, String key, int offset) throws SQLException { String sql = """ @@ -144,9 +140,9 @@ static Object readStream( FROM "%s".streams WHERE workflow_uuid = ? AND key = ? AND "offset" = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); stmt.setString(2, key); @@ -167,8 +163,8 @@ static Object readStream( } } - static Map> getAllStreamEntries( - DataSource dataSource, String schema, String workflowId) throws SQLException { + public static Map> getAllStreamEntries(DbContext ctx, String workflowId) + throws SQLException { String sql = """ SELECT key, value, serialization @@ -176,10 +172,10 @@ static Map> getAllStreamEntries( WHERE workflow_uuid = ? ORDER BY key, "offset" """ - .formatted(schema); + .formatted(ctx.schema()); var streams = new LinkedHashMap>(); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java similarity index 86% rename from transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java index 4119571a2..08408d814 100644 --- a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java @@ -1,6 +1,11 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; import dev.dbos.transact.Constants; +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.MetricData; +import dev.dbos.transact.database.Result; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.database.WorkflowInitResult; import dev.dbos.transact.exceptions.DBOSAwaitedWorkflowCancelledException; import dev.dbos.transact.exceptions.DBOSConflictingWorkflowException; import dev.dbos.transact.exceptions.DBOSMaxRecoveryAttemptsExceededException; @@ -46,12 +51,10 @@ import java.util.UUID; import java.util.stream.Stream; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class WorkflowDAO { +public class WorkflowDAO { private static final Logger logger = LoggerFactory.getLogger(WorkflowDAO.class); @@ -71,10 +74,8 @@ class WorkflowDAO { private WorkflowDAO() {} - static WorkflowInitResult initWorkflowStatus( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + public static WorkflowInitResult initWorkflowStatus( + DbContext ctx, WorkflowStatusInternal initStatus, Integer maxRetries, boolean isRecoveryRequest, @@ -84,17 +85,17 @@ static WorkflowInitResult initWorkflowStatus( logger.debug("initWorkflowStatus workflowId {}", initStatus.workflowId()); - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { boolean shouldCommit = false; try { - connection.setAutoCommit(false); - connection.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); + conn.setAutoCommit(false); + conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); InsertWorkflowResult resRow = insertWorkflowStatus( - connection, schema, initStatus, ownerXid, isRecoveryRequest || isDequeuedRequest); + conn, ctx.schema(), initStatus, ownerXid, isRecoveryRequest || isDequeuedRequest); if (!Objects.equals(resRow.workflowName(), initStatus.workflowName())) { String msg = @@ -142,9 +143,9 @@ static WorkflowInitResult initWorkflowStatus( SET status = ?, deduplication_id = NULL, started_at_epoch_ms = NULL, queue_name = NULL WHERE workflow_uuid = ? AND status = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, WorkflowState.MAX_RECOVERY_ATTEMPTS_EXCEEDED.name()); stmt.setString(2, initStatus.workflowId()); stmt.setString(3, WorkflowState.PENDING.name()); @@ -160,9 +161,9 @@ static WorkflowInitResult initWorkflowStatus( } finally { if (shouldCommit) { - connection.commit(); + conn.commit(); } else { - connection.rollback(); + conn.rollback(); } DebugTriggers.debugTriggerPoint(DebugTriggers.DEBUG_TRIGGER_INITWF_COMMIT); } @@ -188,7 +189,7 @@ static record InsertWorkflowResult( * @throws SQLException */ static InsertWorkflowResult insertWorkflowStatus( - Connection connection, + Connection conn, String schema, WorkflowStatusInternal status, String ownerXid, @@ -239,7 +240,7 @@ ON CONFLICT (workflow_uuid) status.authenticatedRoles() != null ? JsonUtility.toJson(status.authenticatedRoles()) : null; - try (PreparedStatement stmt = connection.prepareStatement(insertSQL)) { + try (var stmt = conn.prepareStatement(insertSQL)) { var now = System.currentTimeMillis(); stmt.setString(1, status.workflowId()); @@ -310,7 +311,7 @@ ON CONFLICT (workflow_uuid) } static void updateWorkflowOutcome( - Connection connection, + Connection conn, String schema, String workflowId, WorkflowState status, @@ -329,7 +330,7 @@ static void updateWorkflowOutcome( """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, status.name()); stmt.setString(2, output); stmt.setString(3, error); @@ -350,11 +351,11 @@ static void updateWorkflowOutcome( * @param workflowId id of the workflow * @param result output serialized as json */ - static void recordWorkflowOutput( - DataSource dataSource, String schema, String workflowId, String result) throws SQLException { + public static void recordWorkflowOutput(DbContext ctx, String workflowId, String result) + throws SQLException { - try (Connection connection = dataSource.getConnection()) { - updateWorkflowOutcome(connection, schema, workflowId, WorkflowState.SUCCESS, result, null); + try (var conn = ctx.getConnection()) { + updateWorkflowOutcome(conn, ctx.schema(), workflowId, WorkflowState.SUCCESS, result, null); } } @@ -364,20 +365,20 @@ static void recordWorkflowOutput( * @param workflowId id of the workflow * @param error output serialized as json */ - static void recordWorkflowError( - DataSource dataSource, String schema, String workflowId, String error) throws SQLException { + public static void recordWorkflowError(DbContext ctx, String workflowId, String error) + throws SQLException { - try (Connection connection = dataSource.getConnection()) { - updateWorkflowOutcome(connection, schema, workflowId, WorkflowState.ERROR, null, error); + try (var conn = ctx.getConnection()) { + updateWorkflowOutcome(conn, ctx.schema(), workflowId, WorkflowState.ERROR, null, error); } } - static String getWorkflowSerialization(DataSource dataSource, String schema, String workflowId) + public static String getWorkflowSerialization(DbContext ctx, String workflowId) throws SQLException { var sql = "SELECT serialization FROM \"%s\".workflow_status WHERE workflow_uuid = ?" - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { @@ -389,16 +390,15 @@ static String getWorkflowSerialization(DataSource dataSource, String schema, Str return null; } - static WorkflowStatus getWorkflowStatus( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) + public static WorkflowStatus getWorkflowStatus(DbContext ctx, String workflowId) throws SQLException { - try (var conn = dataSource.getConnection()) { - return getWorkflowStatus(conn, schema, serializer, workflowId); + try (var conn = ctx.getConnection()) { + return getWorkflowStatus(conn, ctx.schema(), ctx.serializer(), workflowId); } } - static WorkflowStatus getWorkflowStatus( + public static WorkflowStatus getWorkflowStatus( Connection conn, String schema, DBOSSerializer serializer, String workflowId) throws SQLException { if (Objects.requireNonNull(workflowId, "workflowId must not be null").isEmpty()) { @@ -421,8 +421,7 @@ static WorkflowStatus getWorkflowStatus( return null; } - static void setWorkflowDelay( - DataSource dataSource, String schema, String workflowId, WorkflowDelay delay) + public static void setWorkflowDelay(DbContext ctx, String workflowId, WorkflowDelay delay) throws SQLException { Objects.requireNonNull(workflowId, "workflowId must not be null"); Objects.requireNonNull(delay, "delay must not be null"); @@ -446,8 +445,8 @@ static void setWorkflowDelay( WHERE workflow_uuid = ? AND status = ? """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, resolved.toEpochMilli()); stmt.setString(2, workflowId); @@ -457,7 +456,7 @@ static void setWorkflowDelay( } } - static void transitionDelayedWorkflows(DataSource dataSource, String schema) throws SQLException { + public static void transitionDelayedWorkflows(DbContext ctx) throws SQLException { var sql = """ UPDATE "%s".workflow_status @@ -465,9 +464,9 @@ static void transitionDelayedWorkflows(DataSource dataSource, String schema) thr WHERE status = ? AND delay_until_epoch_ms <= ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, WorkflowState.ENQUEUED.name()); stmt.setString(2, WorkflowState.DELAYED.name()); @@ -477,10 +476,10 @@ static void transitionDelayedWorkflows(DataSource dataSource, String schema) thr } } - static List listWorkflows( - DataSource dataSource, String schema, DBOSSerializer serializer, ListWorkflowsInput input) + public static List listWorkflows(DbContext ctx, ListWorkflowsInput input) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); if (input == null) { input = new ListWorkflowsInput(); } @@ -504,7 +503,7 @@ static List listWorkflows( sqlBuilder.append(", serialization"); } - sqlBuilder.append(" FROM \"%s\".workflow_status ".formatted(schema)); + sqlBuilder.append(" FROM \"%s\".workflow_status ".formatted(ctx.schema())); // --- WHERE Clauses --- StringJoiner whereConditions = new StringJoiner(" AND "); @@ -616,7 +615,7 @@ static List listWorkflows( parameters.add(input.offset()); } - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { List arrays = new ArrayList<>(); try { @@ -653,12 +652,9 @@ static List listWorkflows( return workflows; } - static List getWorkflowAggregates( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - GetWorkflowAggregatesInput input) - throws SQLException { + public static List getWorkflowAggregates( + DbContext ctx, GetWorkflowAggregatesInput input) throws SQLException { + if (input == null) { input = new GetWorkflowAggregatesInput(); } @@ -684,7 +680,7 @@ record GroupDim(String name, String column) {} StringJoiner selectCols = new StringJoiner(", "); for (var dim : dims) selectCols.add(dim.column()); selectCols.add("COUNT(*) AS count"); - sqlBuilder.append(selectCols).append(" FROM \"%s\".workflow_status".formatted(schema)); + sqlBuilder.append(selectCols).append(" FROM \"%s\".workflow_status".formatted(ctx.schema())); // --- WHERE --- StringJoiner whereConditions = new StringJoiner(" AND "); @@ -738,7 +734,7 @@ record GroupDim(String name, String column) {} sqlBuilder.append(" ORDER BY ").append(groupByCols); List results = new ArrayList<>(); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { List arrays = new ArrayList<>(); try { @@ -826,40 +822,32 @@ private static WorkflowStatus resultsToWorkflowStatus( return info; } - static List getPendingWorkflows( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List executorIds, - String appVersion) - throws SQLException { + public static List getPendingWorkflows( + DbContext ctx, List executorIds, String appVersion) throws SQLException { var input = new ListWorkflowsInput() .withStatus(WorkflowState.PENDING) .withExecutorIds(executorIds) .withApplicationVersion(appVersion); - return listWorkflows(dataSource, schema, serializer, input); + return listWorkflows(ctx, input); } @SuppressWarnings("unchecked") - static Result awaitWorkflowResult( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - Duration dbPollingInterval, - String workflowId) - throws SQLException { + public static Result awaitWorkflowResult( + DbContext ctx, Duration dbPollingInterval, String workflowId) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); final String sql = """ SELECT status, output, error, serialization FROM "%s".workflow_status WHERE workflow_uuid = ? """ - .formatted(schema); + .formatted(ctx.schema()); while (true) { - try (Connection connection = dataSource.getConnection(); + ctx.checkClosed(); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, workflowId); @@ -901,9 +889,8 @@ static Result awaitWorkflowResult( } } - static void recordChildWorkflow( - DataSource dataSource, - String schema, + public static void recordChildWorkflow( + DbContext ctx, String parentId, String childId, // workflowId of the child int functionId, // func id in the parent @@ -914,22 +901,21 @@ static void recordChildWorkflow( var result = new StepResult(parentId, functionId, functionName, null, null, null, null) .withChildWorkflowId(childId); - try (Connection connection = dataSource.getConnection()) { - StepsDAO.recordStepResultTxn(result, null, null, connection, schema); + try (var conn = ctx.getConnection()) { + StepsDAO.recordStepResult(conn, ctx.schema(), result, null, null); } } - static Optional checkChildWorkflow( - DataSource dataSource, String schema, String workflowUuid, int functionId) - throws SQLException { + public static Optional checkChildWorkflow( + DbContext ctx, String workflowUuid, int functionId) throws SQLException { final String sql = """ SELECT child_workflow_id FROM "%s".operation_outputs WHERE workflow_uuid = ? AND function_id = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, workflowUuid); @@ -952,8 +938,7 @@ private static List filterNullsAndBlanks(List workflowIds) { return workflowIds.stream().filter(id -> id != null && !id.isBlank()).toList(); } - static void cancelWorkflows(DataSource dataSource, String schema, List workflowIds) - throws SQLException { + public static void cancelWorkflows(DbContext ctx, List workflowIds) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { return; @@ -969,9 +954,9 @@ static void cancelWorkflows(DataSource dataSource, String schema, List w WHERE workflow_uuid = ANY(?) AND status NOT IN (?, ?) """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement stmt = conn.prepareStatement(sql)) { Array array = conn.createArrayOf("text", filtered.toArray(String[]::new)); try { @@ -986,8 +971,7 @@ AND status NOT IN (?, ?) } } - static void resumeWorkflows( - DataSource dataSource, String schema, List workflowIds, String queueName) + public static void resumeWorkflows(DbContext ctx, List workflowIds, String queueName) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { @@ -1007,9 +991,9 @@ static void resumeWorkflows( WHERE workflow_uuid = ANY(?) AND status NOT IN (?, ?) """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement stmt = conn.prepareStatement(sql)) { Array array = conn.createArrayOf("text", filtered.toArray(String[]::new)); try { @@ -1025,9 +1009,8 @@ AND status NOT IN (?, ?) } } - static void deleteWorkflows( - DataSource dataSource, String schema, List workflowIds, boolean deleteChildren) - throws SQLException { + public static void deleteWorkflows( + DbContext ctx, List workflowIds, boolean deleteChildren) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { return; @@ -1036,7 +1019,7 @@ static void deleteWorkflows( var wfIdSet = new HashSet(filtered); if (deleteChildren) { for (var wfid : filtered) { - var children = getWorkflowChildren(dataSource, schema, wfid); + var children = getWorkflowChildren(ctx, wfid); wfIdSet.addAll(children); } } @@ -1046,9 +1029,9 @@ static void deleteWorkflows( DELETE FROM "%s".workflow_status WHERE workflow_uuid = ANY(?); """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { var array = conn.createArrayOf("text", wfIdSet.toArray(String[]::new)); try { @@ -1060,7 +1043,7 @@ static void deleteWorkflows( } } - static Set getWorkflowChildren(DataSource dataSource, String schema, String workflowId) + public static Set getWorkflowChildren(DbContext ctx, String workflowId) throws SQLException { var children = new HashSet(); var toProcess = new ArrayDeque(); @@ -1072,9 +1055,9 @@ static Set getWorkflowChildren(DataSource dataSource, String schema, Str FROM "%s".operation_outputs WHERE workflow_uuid = ? AND child_workflow_id IS NOT NULL """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { while (!toProcess.isEmpty()) { var wfid = toProcess.poll(); @@ -1094,18 +1077,13 @@ static Set getWorkflowChildren(DataSource dataSource, String schema, Str return children; } - static String forkWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - String originalWorkflowId, - int startStep, - ForkOptions options) + public static String forkWorkflow( + DbContext ctx, String originalWorkflowId, int startStep, ForkOptions options) throws SQLException { options = Objects.requireNonNullElseGet(options, ForkOptions::new); - var status = getWorkflowStatus(dataSource, schema, serializer, originalWorkflowId); + var status = getWorkflowStatus(ctx, originalWorkflowId); if (status == null) { throw new DBOSNonExistentWorkflowException(originalWorkflowId); } @@ -1124,52 +1102,52 @@ static String forkWorkflow( timeoutMS = explicit.value().toMillis(); } - try (Connection connection = dataSource.getConnection()) { - connection.setAutoCommit(false); + try (var conn = ctx.getConnection()) { + conn.setAutoCommit(false); try { // Create entry for forked workflow insertForkedWorkflowStatus( - connection, + conn, + ctx.schema(), + ctx.serializer(), originalWorkflowId, forkedWorkflowId, status, options.applicationVersion(), timeoutMS, options.queueName(), - options.queuePartitionKey(), - schema, - serializer); + options.queuePartitionKey()); // Copy operation outputs if starting from step > 0 if (startStep > 0) { - copyOperationOutputs(connection, originalWorkflowId, forkedWorkflowId, startStep, schema); + copyOperationOutputs(conn, ctx.schema(), originalWorkflowId, forkedWorkflowId, startStep); } // Mark the original workflow as having been forked - markWasForkedFrom(connection, originalWorkflowId, schema); + markWasForkedFrom(conn, ctx.schema(), originalWorkflowId); - connection.commit(); + conn.commit(); return forkedWorkflowId; } catch (SQLException e) { - connection.rollback(); + conn.rollback(); throw e; } } } private static void insertForkedWorkflowStatus( - Connection connection, + Connection conn, + String schema, + DBOSSerializer serializer, String originalWorkflowId, String forkedWorkflowId, WorkflowStatus originalStatus, String applicationVersion, Long timeoutMS, String queueName, - String queuePartitionKey, - String schema, - DBOSSerializer serializer) + String queuePartitionKey) throws SQLException { Objects.requireNonNull(schema); @@ -1183,7 +1161,7 @@ private static void insertForkedWorkflowStatus( """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, WorkflowState.ENQUEUED.name()); stmt.setString(3, originalStatus.workflowName()); @@ -1213,7 +1191,7 @@ private static void insertForkedWorkflowStatus( } } - private static void markWasForkedFrom(Connection connection, String workflowId, String schema) + private static void markWasForkedFrom(Connection conn, String schema, String workflowId) throws SQLException { String sql = """ @@ -1222,18 +1200,18 @@ private static void markWasForkedFrom(Connection connection, String workflowId, WHERE workflow_uuid = ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); stmt.executeUpdate(); } } private static void copyOperationOutputs( - Connection connection, + Connection conn, + String schema, String originalWorkflowId, String forkedWorkflowId, - int startStep, - String schema) + int startStep) throws SQLException { String stepOutputsSql = @@ -1245,7 +1223,7 @@ private static void copyOperationOutputs( WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(stepOutputsSql)) { + try (var stmt = conn.prepareStatement(stepOutputsSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1263,7 +1241,7 @@ private static void copyOperationOutputs( WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(eventHistorySql)) { + try (var stmt = conn.prepareStatement(eventHistorySql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1289,7 +1267,7 @@ SELECT MAX(weh2.function_id) """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(eventSql)) { + try (var stmt = conn.prepareStatement(eventSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setString(3, originalWorkflowId); @@ -1308,7 +1286,7 @@ SELECT MAX(weh2.function_id) WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(streamsSql)) { + try (var stmt = conn.prepareStatement(streamsSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1318,14 +1296,14 @@ SELECT MAX(weh2.function_id) } } - private static Instant getRowsCutoff(Connection connection, long rowsThreshold, String schema) + private static Instant getRowsCutoff(Connection conn, String schema, long rowsThreshold) throws SQLException { String sql = """ SELECT created_at FROM "%s".workflow_status ORDER BY created_at DESC OFFSET ? LIMIT 1 """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, rowsThreshold - 1); try (ResultSet rs = stmt.executeQuery()) { if (rs.next()) { @@ -1337,13 +1315,12 @@ private static Instant getRowsCutoff(Connection connection, long rowsThreshold, return null; } - static void garbageCollect( - DataSource dataSource, String schema, Instant cutoff, Long rowsThreshold) + public static void garbageCollect(DbContext ctx, Instant cutoff, Long rowsThreshold) throws SQLException { - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { if (rowsThreshold != null) { - var rowsCutoff = getRowsCutoff(connection, rowsThreshold, schema); + var rowsCutoff = getRowsCutoff(conn, ctx.schema(), rowsThreshold); if (rowsCutoff != null) { if (cutoff == null || rowsCutoff.isAfter(cutoff)) { cutoff = rowsCutoff; @@ -1356,8 +1333,8 @@ static void garbageCollect( """ DELETE FROM "%s".workflow_status WHERE created_at < ? AND status NOT IN (?, ?, ?) """ - .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + .formatted(ctx.schema()); + try (var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, cutoff.toEpochMilli()); stmt.setString(2, WorkflowState.PENDING.name()); stmt.setString(3, WorkflowState.ENQUEUED.name()); @@ -1369,8 +1346,7 @@ static void garbageCollect( } } - static List getMetrics( - DataSource dataSource, String schema, Instant startTime, Instant endTime) + public static List getMetrics(DbContext ctx, Instant startTime, Instant endTime) throws SQLException { final var start = Objects.requireNonNull(startTime).toEpochMilli(); final var end = Objects.requireNonNull(endTime).toEpochMilli(); @@ -1383,7 +1359,7 @@ SELECT name, COUNT(workflow_uuid) as count WHERE created_at >= ? AND created_at < ? GROUP BY name """ - .formatted(schema); + .formatted(ctx.schema()); final var stepSQL = """ SELECT function_name, COUNT(*) as count @@ -1391,9 +1367,9 @@ SELECT function_name, COUNT(*) as count WHERE completed_at_epoch_ms >= ? AND completed_at_epoch_ms < ? GROUP BY function_name """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var ps1 = conn.prepareStatement(wfSQL); var ps2 = conn.prepareStatement(stepSQL)) { @@ -1501,41 +1477,36 @@ static List listWorkflowStreams(Connection conn, String schema, return streams; } - static List exportWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - String workflowId, - boolean exportChildren) - throws SQLException { + public static List exportWorkflow( + DbContext ctx, String workflowId, boolean exportChildren) throws SQLException { + var workflowIds = exportChildren ? Stream.concat( - getWorkflowChildren(dataSource, schema, workflowId).stream(), - List.of(workflowId).stream()) + getWorkflowChildren(ctx, workflowId).stream(), List.of(workflowId).stream()) .toList() : List.of(workflowId); var workflows = new ArrayList(); for (var wfid : workflowIds) { - try (var conn = dataSource.getConnection()) { - var status = getWorkflowStatus(conn, schema, serializer, wfid); - var steps = StepsDAO.listWorkflowSteps(conn, wfid, true, null, null, schema, serializer); - var events = listWorkflowEvents(conn, schema, wfid); - var eventHistory = listWorkflowEventHistory(conn, schema, wfid); - var streams = listWorkflowStreams(conn, schema, wfid); + try (var conn = ctx.getConnection()) { + var status = getWorkflowStatus(conn, ctx.schema(), ctx.serializer(), wfid); + var steps = + StepsDAO.listWorkflowSteps( + conn, ctx.schema(), ctx.serializer(), wfid, true, null, null); + var events = listWorkflowEvents(conn, ctx.schema(), wfid); + var eventHistory = listWorkflowEventHistory(conn, ctx.schema(), wfid); + var streams = listWorkflowStreams(conn, ctx.schema(), wfid); workflows.add(new ExportedWorkflow(status, steps, events, eventHistory, streams)); } } return workflows; } - static void importWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List workflows) + public static void importWorkflow(DbContext ctx, List workflows) throws SQLException { + + DBOSSerializer serializer = ctx.serializer(); var wfSQL = """ INSERT INTO "%s".workflow_status ( @@ -1553,7 +1524,7 @@ static void importWorkflow( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var stepSQL = """ @@ -1566,7 +1537,7 @@ static void importWorkflow( ?, ?, ?, ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var eventSQL = """ @@ -1576,7 +1547,7 @@ static void importWorkflow( ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var eventHistorySQL = """ @@ -1586,7 +1557,7 @@ static void importWorkflow( ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var streamsSQL = """ @@ -1596,9 +1567,9 @@ static void importWorkflow( ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try (var wfStmt = conn.prepareStatement(wfSQL); @@ -1723,16 +1694,16 @@ static void importWorkflow( } } - static Map getAllEvents( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) + public static Map getAllEvents(DbContext ctx, String workflowId) throws SQLException { - try (var conn = dataSource.getConnection()) { - var events = listWorkflowEvents(conn, schema, workflowId); + try (var conn = ctx.getConnection()) { + var events = listWorkflowEvents(conn, ctx.schema(), workflowId); var result = new LinkedHashMap(); for (var event : events) { result.put( event.key(), - SerializationUtil.deserializeValue(event.value(), event.serialization(), serializer)); + SerializationUtil.deserializeValue( + event.value(), event.serialization(), ctx.serializer())); } return result; } diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java b/transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java new file mode 100644 index 000000000..0524483a6 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java @@ -0,0 +1,39 @@ +package dev.dbos.transact.database.signal; + +public sealed interface SignalKey + permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { + + public enum WakeReason { + MESSAGE, + EVENT, + CANCELLED, + SHUTDOWN, + TIMEOUT + } + + WakeReason wakeReason(); + + record Cancellation(String workflowId) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.CANCELLED; + } + } + + record Event(String workflowId, String key) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.EVENT; + } + } + + record Message(String workflowId, String topic) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.MESSAGE; + } + } + + record Shutdown() implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.SHUTDOWN; + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java new file mode 100644 index 000000000..4f3565a20 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java @@ -0,0 +1,73 @@ +package dev.dbos.transact.database.signal; + +import dev.dbos.transact.database.signal.SignalKey.WakeReason; + +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +public class SignalMap { + private static class Entry { + final CompletableFuture future = new CompletableFuture<>(); + final AtomicInteger refs = new AtomicInteger(1); + final WakeReason reason; + + public Entry(WakeReason reason) { + this.reason = Objects.requireNonNull(reason); + } + } + + private final ConcurrentHashMap map = new ConcurrentHashMap<>(); + + public Subscription subscribe(K key, WakeReason reason) { + var entry = + map.compute( + key, + (k, e) -> { + if (e != null) { + e.refs.incrementAndGet(); + return e; + } + return new Entry(reason); + }); + + var sub = + new Subscription( + () -> + map.compute(key, (k, e) -> e != null && e.refs.decrementAndGet() == 0 ? null : e)); + + entry.future.thenAccept( + r -> { + if (!sub.closed) { + sub.complete(r); + } + }); + return sub; + } + + public void signal(K key) { + var e = map.remove(key); + if (e != null) { + e.future.complete(e.reason); + } + } + + public static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { + try { + return (WakeReason) + CompletableFuture.anyOf(subscriptions).get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException ignored) { + return WakeReason.TIMEOUT; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java new file mode 100644 index 000000000..8d7788fb9 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java @@ -0,0 +1,20 @@ +package dev.dbos.transact.database.signal; + +import dev.dbos.transact.database.signal.SignalKey.WakeReason; + +import java.util.concurrent.CompletableFuture; + +public class Subscription extends CompletableFuture implements AutoCloseable { + private final Runnable onClose; + volatile boolean closed = false; + + public Subscription(Runnable onClose) { + this.onClose = onClose; + } + + @Override + public void close() { + closed = true; + onClose.run(); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java index 51a421645..92149f724 100644 --- a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java +++ b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java @@ -11,7 +11,7 @@ import dev.dbos.transact.context.DBOSContextHolder; import dev.dbos.transact.context.WorkflowInfo; import dev.dbos.transact.database.ExternalState; -import dev.dbos.transact.database.GetWorkflowEventContext; +import dev.dbos.transact.database.GetEventCaller; import dev.dbos.transact.database.Result; import dev.dbos.transact.database.StreamIterator; import dev.dbos.transact.database.SystemDatabase; @@ -471,11 +471,10 @@ public Object getEvent(String workflowId, String key, Duration timeout) { DBOSContext ctx = DBOSContextHolder.get(); if (ctx.isInWorkflow() && !ctx.isInStep()) { - int stepFunctionId = ctx.getAndIncrementFunctionId(); - int timeoutFunctionId = ctx.getAndIncrementFunctionId(); - GetWorkflowEventContext callerCtx = - new GetWorkflowEventContext(ctx.getWorkflowId(), stepFunctionId, timeoutFunctionId); - return systemDatabase.getEvent(workflowId, key, timeout, callerCtx); + int stepId = ctx.getAndIncrementFunctionId(); + int timeoutStepId = ctx.getAndIncrementFunctionId(); + GetEventCaller caller = new GetEventCaller(ctx.getWorkflowId(), stepId, timeoutStepId); + return systemDatabase.getEvent(workflowId, key, timeout, caller); } return systemDatabase.getEvent(workflowId, key, timeout, null); @@ -1045,7 +1044,7 @@ private T runStepInternal( var stepId = ctx.getAndIncrementFunctionId(); logger.debug("executeStep #{} ({}) for workflow {}", stepId, options.name(), workflowId); - var prevResult = systemDatabase.checkStepExecutionTxn(workflowId, stepId, options.name()); + var prevResult = systemDatabase.checkStepResult(workflowId, stepId, options.name()); if (prevResult != null) { if (prevResult.error() != null) { var t = @@ -1122,7 +1121,7 @@ private T runStepInternal( serializedException.serializedValue(), childWorkflowId, serializedException.serialization()); - systemDatabase.recordStepResultTxn(stepResult, startTime); + systemDatabase.recordStepResult(stepResult, startTime); throw (E) exception; } else { logger.debug("executeStep #{} for workflow {} completed {}", stepId, workflowId, output); @@ -1137,7 +1136,7 @@ private T runStepInternal( null, childWorkflowId, serializedOutput.serialization()); - systemDatabase.recordStepResultTxn(stepResult, startTime); + systemDatabase.recordStepResult(stepResult, startTime); return output; } } diff --git a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java index ce6597cd7..1fe7d686c 100644 --- a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java +++ b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java @@ -23,29 +23,30 @@ public static void runMigrations(DBOSConfig config) { Objects.requireNonNull(config, "DBOS Config must not be null"); if (config.dataSource() != null) { - runMigrations(config.dataSource(), config.databaseSchema()); + runMigrations(config.dataSource(), config.databaseSchema(), config.useListenNotify()); } else { createDatabaseIfNotExists(config.databaseUrl(), config.dbUser(), config.dbPassword()); try (var ds = SystemDatabase.createDataSource( config.databaseUrl(), config.dbUser(), config.dbPassword())) { - runMigrations(ds, config.databaseSchema()); + runMigrations(ds, config.databaseSchema(), config.useListenNotify()); } } } - public static void runMigrations(String url, String user, String password, String schema) { + public static void runMigrations( + String url, String user, String password, String schema, boolean useListenNotify) { Objects.requireNonNull(url, "database url must not be null"); Objects.requireNonNull(user, "database user must not be null"); Objects.requireNonNull(password, "database password must not be null"); createDatabaseIfNotExists(url, user, password); try (var ds = SystemDatabase.createDataSource(url, user, password)) { - runMigrations(ds, schema); + runMigrations(ds, schema, useListenNotify); } } - private static void runMigrations(DataSource ds, String schema) { + private static void runMigrations(DataSource ds, String schema, boolean useListenNotify) { Objects.requireNonNull(ds, "Data Source must not be null"); schema = SystemDatabase.sanitizeSchema(schema); @@ -54,9 +55,15 @@ private static void runMigrations(DataSource ds, String schema) { } try (var conn = ds.getConnection()) { + + var isCockroach = SystemDatabase.isCockroach(conn); + if (isCockroach) { + useListenNotify = false; + } + ensureDbosSchema(conn, schema); ensureMigrationTable(conn, schema); - var migrations = getMigrations(schema); + var migrations = getMigrations(schema, useListenNotify); runDbosMigrations(conn, schema, migrations); } catch (SQLException e) { throw new RuntimeException("Failed to run migrations", e); @@ -173,6 +180,19 @@ public static int getCurrentSysDbVersion(Connection conn, String schema) { return 0; } + private static boolean notificationsPrimaryKeyExists(Connection conn, String schema) + throws SQLException { + var sql = + "SELECT 1 FROM information_schema.table_constraints" + + " WHERE table_schema = ? AND table_name = 'notifications' AND constraint_type = 'PRIMARY KEY'"; + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, schema); + try (var rs = stmt.executeQuery()) { + return rs.next(); + } + } + } + static void runDbosMigrations(Connection conn, String schema, List migrations) { Objects.requireNonNull(schema, "schema must not be null"); var lastApplied = getCurrentSysDbVersion(conn, schema); @@ -184,10 +204,27 @@ static void runDbosMigrations(Connection conn, String schema, List migra } logger.info("Applying DBOS system database schema migration {}", migrationIndex); - try (var stmt = conn.createStatement()) { - stmt.execute(migrations.get(i)); - } catch (SQLException e) { - throw new RuntimeException("Failed to run migration %d".formatted(migrationIndex), e); + + // Migration 10 adds a primary key to notifications. Skip the DDL if one already exists + // (guard for installs that were created before the primary key was added to migration 1). + boolean skipMigration = false; + if (migrationIndex == 10) { + try { + skipMigration = notificationsPrimaryKeyExists(conn, schema); + } catch (SQLException e) { + throw new RuntimeException("Failed to check notifications primary key", e); + } + if (skipMigration) { + logger.info("Migration 10 skipped, primary key already exists"); + } + } + + if (!skipMigration) { + try (var stmt = conn.createStatement()) { + stmt.execute(migrations.get(i)); + } catch (SQLException e) { + throw new RuntimeException("Failed to run migration %d".formatted(migrationIndex), e); + } } try { @@ -212,11 +249,12 @@ static void runDbosMigrations(Connection conn, String schema, List migra } } - public static List getMigrations(String schema) { + public static List getMigrations(String schema, boolean useListenNotify) { Objects.requireNonNull(schema); + var migrations = List.of( - MIGRATION_1, + migration1(useListenNotify), MIGRATION_2, MIGRATION_3, MIGRATION_4, @@ -238,8 +276,13 @@ public static List getMigrations(String schema) { return migrations.stream().map(m -> m.formatted(schema)).toList(); } + static String migration1(boolean useListenNotify) { + return useListenNotify ? MIGRATION_1 + MIGRATION_1_NOTIFY : MIGRATION_1; + } + static final String MIGRATION_1 = """ + -- Enable uuid extension for generating UUIDs CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE TABLE "%1$s".workflow_status ( @@ -253,8 +296,8 @@ public static List getMigrations(String schema) { output TEXT, error TEXT, executor_id TEXT, - created_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, - updated_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, + created_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, + updated_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, application_version TEXT, application_id TEXT, class_name VARCHAR(255) DEFAULT NULL, @@ -294,12 +337,44 @@ message_uuid TEXT NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, -- Built-in fu destination_uuid TEXT NOT NULL, topic TEXT, message TEXT NOT NULL, - created_at_epoch_ms BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, + created_at_epoch_ms BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, FOREIGN KEY (destination_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) ON UPDATE CASCADE ON DELETE CASCADE ); CREATE INDEX idx_workflow_topic ON "%1$s".notifications (destination_uuid, topic); + CREATE TABLE "%1$s".workflow_events ( + workflow_uuid TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (workflow_uuid, key), + FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) + ON UPDATE CASCADE ON DELETE CASCADE + ); + + CREATE TABLE "%1$s".streams ( + workflow_uuid TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + "offset" INT4 NOT NULL, + PRIMARY KEY (workflow_uuid, key, "offset"), + FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) + ON UPDATE CASCADE ON DELETE CASCADE + ); + + CREATE TABLE "%1$s".event_dispatch_kv ( + service_name TEXT NOT NULL, + workflow_fn_name TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + update_seq NUMERIC(38,0), + update_time NUMERIC(38,15), + PRIMARY KEY (service_name, workflow_fn_name, key) + ); + """; + + static final String MIGRATION_1_NOTIFY = + """ -- Create notification function CREATE OR REPLACE FUNCTION "%1$s".notifications_function() RETURNS TRIGGER AS $$ DECLARE @@ -315,15 +390,6 @@ FOREIGN KEY (destination_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) AFTER INSERT ON "%1$s".notifications FOR EACH ROW EXECUTE FUNCTION "%1$s".notifications_function(); - CREATE TABLE "%1$s".workflow_events ( - workflow_uuid TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - PRIMARY KEY (workflow_uuid, key), - FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) - ON UPDATE CASCADE ON DELETE CASCADE - ); - -- Create events function CREATE OR REPLACE FUNCTION "%1$s".workflow_events_function() RETURNS TRIGGER AS $$ DECLARE @@ -338,26 +404,6 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) CREATE TRIGGER dbos_workflow_events_trigger AFTER INSERT ON "%1$s".workflow_events FOR EACH ROW EXECUTE FUNCTION "%1$s".workflow_events_function(); - - CREATE TABLE "%1$s".streams ( - workflow_uuid TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - "offset" INT4 NOT NULL, - PRIMARY KEY (workflow_uuid, key, "offset"), - FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) - ON UPDATE CASCADE ON DELETE CASCADE - ); - - CREATE TABLE "%1$s".event_dispatch_kv ( - service_name TEXT NOT NULL, - workflow_fn_name TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT, - update_seq NUMERIC(38,0), - update_time NUMERIC(38,15), - PRIMARY KEY (service_name, workflow_fn_name, key) - ); """; static final String MIGRATION_2 = @@ -397,7 +443,7 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) static final String MIGRATION_7 = """ - ALTER TABLE "%1$s"."workflow_status" ADD COLUMN "owner_xid" VARCHAR(40) DEFAULT NULL + ALTER TABLE "%1$s"."workflow_status" ADD COLUMN "owner_xid" TEXT DEFAULT NULL """; static final String MIGRATION_8 = @@ -421,17 +467,7 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) static final String MIGRATION_10 = """ - DO $$ - BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = '%1$s' - AND table_name = 'notifications' - AND constraint_type = 'PRIMARY KEY' - ) THEN - ALTER TABLE "%1$s".notifications ADD PRIMARY KEY (message_uuid); - END IF; - END $$; + ALTER TABLE "%1$s".notifications ADD PRIMARY KEY (message_uuid); """; static final String MIGRATION_11 = diff --git a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java index c706e9121..0db440dde 100644 --- a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java +++ b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java @@ -2,6 +2,7 @@ import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.json.SerializationUtil.SerializedResult; public record StepResult( String workflowId, @@ -32,6 +33,18 @@ public StepResult withSerialization(String v) { return new StepResult(workflowId, stepId, stepName, output, error, childWorkflowId, v); } + public static StepResult ofOutput( + String workflowId, int stepId, String stepName, SerializedResult result) { + return new StepResult( + workflowId, stepId, stepName, result.serializedValue(), null, null, result.serialization()); + } + + public static StepResult ofError( + String workflowId, int stepId, String stepName, SerializedResult result) { + return new StepResult( + workflowId, stepId, stepName, null, result.serializedValue(), null, result.serialization()); + } + @SuppressWarnings("unchecked") public R toResult(DBOSSerializer serializer) throws E { if (error != null) { diff --git a/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java b/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java index a549972ec..4c4146231 100644 --- a/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java +++ b/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java @@ -198,28 +198,40 @@ String enqueueWorkflow( .toArray(String[]::new); var sql = """ - SELECT dbos.enqueue_workflow( - workflow_name => ?, - class_name => ?, - queue_name => ?, - positional_args => ?, - deduplication_id => ?, - timeout_ms => ?, - deadline_epoch_ms => ?) - """; + SELECT dbos.enqueue_workflow( + ?, -- workflow_name + ?, -- queue_name + ?, -- positional_args + ?, -- named_args + ?, -- class_name + ?, -- config_name + ?, -- workflow_id + ?, -- app_version + ?, -- timeout_ms + ?, -- deadline_epoch_ms + ?, -- deduplication_id + ?, -- priority + ? -- queue_partition_key + ) + """; try (var conn = dataSource.getConnection(); var stmt = conn.prepareCall(sql)) { Long timeoutMS = timeout == null ? null : timeout.toMillis(); Long deadlineMS = deadline == null ? null : deadline.toEpochMilli(); - stmt.setString(1, Objects.requireNonNull(workflowName)); - stmt.setString(2, "ClientServiceImpl"); - stmt.setString(3, "testQueue"); - stmt.setString(5, dedupId); - stmt.setObject(6, timeoutMS, Types.BIGINT); - stmt.setObject(7, deadlineMS, Types.BIGINT); - var argsArray = conn.createArrayOf("json", jsonArgs); - stmt.setObject(4, argsArray); + stmt.setString(1, Objects.requireNonNull(workflowName)); // workflow_name + stmt.setString(2, "testQueue"); // queue_name + stmt.setObject(3, argsArray); // positional_args + stmt.setObject(4, null); // named_args + stmt.setString(5, "ClientServiceImpl"); // class_name + stmt.setObject(6, null); // config_name + stmt.setObject(7, null); // workflow_id + stmt.setObject(8, null); // app_version + stmt.setObject(9, timeoutMS, Types.BIGINT); // timeout_ms + stmt.setObject(10, deadlineMS, Types.BIGINT); // deadline_epoch_ms + stmt.setString(11, dedupId); // deduplication_id + stmt.setObject(12, null); // priority + stmt.setObject(13, null); // queue_partition_key try (ResultSet rs = stmt.executeQuery()) { return rs.next() ? rs.getString(1) : null; } finally { @@ -232,7 +244,7 @@ void sendMessage(String destinationId, Object message, String topic, String idem throws SQLException, JsonProcessingException { String jsonMessage = MAPPER.writeValueAsString(Objects.requireNonNull(message)); - var sql = "SELECT dbos.send_message(?, ?::json, topic => ?, idempotency_key => ?)"; + var sql = "SELECT dbos.send_message(?, ?::json, ?, ?)"; try (var conn = dataSource.getConnection(); var stmt = conn.prepareCall(sql)) { stmt.setString(1, Objects.requireNonNull(destinationId)); diff --git a/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java b/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java index 68efe1cbb..25304c963 100644 --- a/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java +++ b/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java @@ -276,6 +276,10 @@ public void calcAppVersionNotMatch() throws Exception { @Test public void configPGSimpleDataSource() throws Exception { + // dbos.launch doesn't attempt to create the database when + // receiving data source in DBOSConfig + pgContainer.createDatabase(); + var jdbcUrl = pgContainer.jdbcUrl(); assertTrue(jdbcUrl.startsWith("jdbc:")); @@ -316,6 +320,10 @@ public void configPGSimpleDataSource() throws Exception { @Test public void configHikariDataSource() throws Exception { + // need to create database since we are connecting + // the data source prior to dbos launch + pgContainer.createDatabase(); + var poolName = "dbos-configDataSource"; HikariConfig hikariConfig = new HikariConfig(); diff --git a/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java b/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java index 8cc4a218e..4423e20ad 100644 --- a/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java @@ -59,8 +59,8 @@ private void createWorkflow(String wfId) throws SQLException { sysdb.recordWorkflowOutput(wfId, null); long now = System.currentTimeMillis(); - sysdb.recordStepResultTxn(new StepResult(wfId, 0, "step0"), now - 2000); - sysdb.recordStepResultTxn(new StepResult(wfId, 1, "step1"), now - 1000); + sysdb.recordStepResult(new StepResult(wfId, 0, "step0"), now - 2000); + sysdb.recordStepResult(new StepResult(wfId, 1, "step1"), now - 1000); // asStep=false: writes to workflow_events + workflow_events_history, not operation_outputs sysdb.setEvent(wfId, 0, "event-key-1", "event-val-1", false, null); diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java new file mode 100644 index 000000000..d051e0013 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java @@ -0,0 +1,478 @@ +package dev.dbos.transact.database; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalKey.WakeReason; +import dev.dbos.transact.database.signal.SignalMap; +import dev.dbos.transact.database.signal.Subscription; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +class SignalMapTest { + + SignalMap map; + + static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); + static final SignalKey KEY_A = new SignalKey.Cancellation("wf-a"); + static final SignalKey KEY_B = new SignalKey.Cancellation("wf-b"); + static final SignalKey FOO = new SignalKey.Cancellation("foo"); + + @BeforeEach + void setup() { + map = new SignalMap<>(); + } + + private static Subscription never() { + return new Subscription(() -> {}); + } + + // --- SignalKey structural equality --- + + @Test + void testSignalKeyStructuralEquality() { + assertEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-1")); + assertEquals(new SignalKey.Event("wf-1", "t"), new SignalKey.Event("wf-1", "t")); + assertNotEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-2")); + assertNotEquals( + (SignalKey) new SignalKey.Cancellation("wf-1"), + (SignalKey) new SignalKey.Event("wf-1", "wf-1")); + } + + // --- Core subscribe / signal behaviour --- + + @Test + void testBasicSubscribeAndSignal() throws Exception { + var f = map.subscribe(KEY, KEY.wakeReason()); + assertFalse(f.isDone()); + map.signal(KEY); + assertTrue(f.isDone()); + assertFalse(f.isCompletedExceptionally()); + assertEquals(WakeReason.CANCELLED, f.get()); + } + + @Test + void testMultipleListenersOnSameKey() throws Exception { + var f1 = map.subscribe(KEY, KEY.wakeReason()); + var f2 = map.subscribe(KEY, KEY.wakeReason()); + var f3 = map.subscribe(KEY, KEY.wakeReason()); + + assertFalse(f1.isDone()); + assertFalse(f2.isDone()); + assertFalse(f3.isDone()); + + map.signal(KEY); + + assertEquals(WakeReason.CANCELLED, f1.get()); + assertEquals(WakeReason.CANCELLED, f2.get()); + assertEquals(WakeReason.CANCELLED, f3.get()); + } + + @Test + void testMultipleSubscriptionsInAnyOf() throws Exception { + var f1 = map.subscribe(KEY_A, KEY_A.wakeReason()); + var f2 = map.subscribe(KEY_B, KEY_B.wakeReason()); + + var anyOf = CompletableFuture.anyOf(f1, f2); + assertFalse(anyOf.isDone()); + + map.signal(KEY_B); + + assertEquals(WakeReason.CANCELLED, (WakeReason) anyOf.get(1, TimeUnit.SECONDS)); + assertTrue(f2.isDone()); + assertFalse(f1.isDone()); + } + + @Test + void testSignalOnlyWakesMatchingKey() { + var f1 = map.subscribe(KEY_A, KEY_A.wakeReason()); + var f2 = map.subscribe(KEY_B, KEY_B.wakeReason()); + + map.signal(KEY_A); + + assertTrue(f1.isDone()); + assertFalse(f2.isDone()); + } + + @Test + void testDifferentKeyTypesWithSameFieldsDoNotCollide() { + var eventKey = new SignalKey.Event("wf-1", "wf-1"); + var cancellationKey = new SignalKey.Cancellation("wf-1"); + + var f1 = map.subscribe(eventKey, eventKey.wakeReason()); + var f2 = map.subscribe(cancellationKey, cancellationKey.wakeReason()); + + map.signal(cancellationKey); + + assertTrue(f2.isDone()); + assertFalse(f1.isDone()); + } + + @Test + void testSignalBeforeSubscribeDoesNotWake() { + map.signal(KEY); + + var f = map.subscribe(KEY, KEY.wakeReason()); + assertFalse(f.isDone()); + + map.signal(KEY); + assertTrue(f.isDone()); + } + + @Test + void testSignalIsOneShot() { + var f1 = map.subscribe(KEY, KEY.wakeReason()); + map.signal(KEY); + assertTrue(f1.isDone()); + + var f2 = map.subscribe(KEY, KEY.wakeReason()); + assertFalse(f2.isDone()); + } + + @Test + void testWakeReasonByKeyType() throws Exception { + var msgSub = map.subscribe(new SignalKey.Message("wf-1", "topic"), WakeReason.MESSAGE); + var eventSub = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var cancelSub = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var shutdownSub = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); + + map.signal(new SignalKey.Message("wf-1", "topic")); + map.signal(new SignalKey.Event("wf-1", "topic")); + map.signal(new SignalKey.Cancellation("wf-1")); + map.signal(new SignalKey.Shutdown()); + + assertEquals(WakeReason.MESSAGE, msgSub.get()); + assertEquals(WakeReason.EVENT, eventSub.get()); + assertEquals(WakeReason.CANCELLED, cancelSub.get()); + assertEquals(WakeReason.SHUTDOWN, shutdownSub.get()); + } + + // --- Subscription / close --- + + @Test + void testCloseOneSubscriberDoesNotOrphanOthers() throws Exception { + var sub1 = map.subscribe(KEY, KEY.wakeReason()); + var sub2 = map.subscribe(KEY, KEY.wakeReason()); + + sub1.close(); + + map.signal(KEY); + assertEquals(WakeReason.CANCELLED, sub2.get()); + } + + @Test + void testClosePreventsFutureFromBeingSignalled() throws Exception { + var sub = map.subscribe(KEY, KEY.wakeReason()); + sub.close(); + map.signal(KEY); + + boolean completed = + sub.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); + assertFalse(completed); + } + + @Test + void testNeverFutureNeverCompletes() throws Exception { + var f = never(); + assertFalse(f.isDone()); + + boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); + assertFalse(completed); + } + + // --- Threading --- + + @Test + void testSubscribeBeforeSignalFromAnotherThread() throws Exception { + var f = map.subscribe(FOO, FOO.wakeReason()); + + CompletableFuture.runAsync( + () -> { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + map.signal(FOO); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) f::get); + } + + @Test + void testSignalFromMainAfterBackgroundSubscribes() throws Exception { + var backgroundDone = new CompletableFuture(); + + CompletableFuture.runAsync( + () -> { + var f = map.subscribe(FOO, FOO.wakeReason()); + try { + backgroundDone.complete(f.get(500, TimeUnit.MILLISECONDS)); + } catch (Exception e) { + backgroundDone.completeExceptionally(e); + } + }); + + Thread.sleep(100); + map.signal(FOO); + + assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) backgroundDone::get); + assertEquals(WakeReason.CANCELLED, backgroundDone.get()); + } + + @Test + void testMultipleSubscribersInSeparateThreads() throws Exception { + var done1 = new CompletableFuture(); + var done2 = new CompletableFuture(); + + CompletableFuture.runAsync( + () -> { + try { + done1.complete(map.subscribe(FOO, FOO.wakeReason()).get(500, TimeUnit.MILLISECONDS)); + } catch (Exception e) { + done1.completeExceptionally(e); + } + }); + CompletableFuture.runAsync( + () -> { + try { + done2.complete(map.subscribe(FOO, FOO.wakeReason()).get(500, TimeUnit.MILLISECONDS)); + } catch (Exception e) { + done2.completeExceptionally(e); + } + }); + + Thread.sleep(100); + map.signal(FOO); + + assertTimeoutPreemptively( + Duration.ofSeconds(1), + () -> { + assertEquals(WakeReason.CANCELLED, done1.get()); + assertEquals(WakeReason.CANCELLED, done2.get()); + }); + } + + @Test + void testConcurrentSignalAndSubscribe() throws Exception { + assertTimeoutPreemptively( + Duration.ofSeconds(5), + () -> { + for (int i = 0; i < 1000; i++) { + var m = new SignalMap(); + var sub = m.subscribe(KEY, KEY.wakeReason()); + CompletableFuture.runAsync(() -> m.signal(KEY)); + sub.get(1, TimeUnit.SECONDS); + } + }); + } + + // --- awaitAny --- + + @Test + void testAwaitAny_notifyFires() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var onShutdown = never(); + + map.signal(new SignalKey.Event("wf-1", "topic")); + + assertEquals( + WakeReason.EVENT, + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + } + + @Test + void testAwaitAny_cancelledFires() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var onShutdown = never(); + + map.signal(new SignalKey.Cancellation("wf-1")); + + assertEquals( + WakeReason.CANCELLED, + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + } + + @Test + void testAwaitAny_shutdownFires() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); + + map.signal(new SignalKey.Shutdown()); + + assertEquals( + WakeReason.SHUTDOWN, + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + } + + @Test + void testAwaitAny_timeout() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); + + assertEquals( + WakeReason.TIMEOUT, + SignalMap.awaitAny(Duration.ofMillis(50), onNotify, onCancelled, onShutdown)); + } + + // --- anyOf determination via isDone (Option A) --- + + @Test + void testCheckIsDone_notifyFires() throws Exception { + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); + + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); + + map.signal(notifyKey); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onCancelled.isDone() || onShutdown.isDone()); + assertTrue(onNotify.isDone()); + } + + @Test + void testCheckIsDone_cancelledFires() throws Exception { + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); + + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); + + map.signal(cancelKey); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertTrue(onCancelled.isDone() || onShutdown.isDone()); + assertFalse(onNotify.isDone()); + } + + @Test + void testCheckIsDone_shutdownFires() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); + + map.signal(new SignalKey.Shutdown()); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertTrue(onCancelled.isDone() || onShutdown.isDone()); + assertFalse(onNotify.isDone()); + } + + @Test + void testCheckIsDone_timeout() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(50, TimeUnit.MILLISECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onCancelled.isDone() || onShutdown.isDone()); + assertFalse(onNotify.isDone()); + } + + // --- anyOf determination via tagged dispatch (Option B) --- + + @Test + void testTaggedDispatch_notifyFires() throws Exception { + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); + + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); + + map.signal(notifyKey); + + var reason = + (WakeReason) + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.EVENT, reason); + } + + @Test + void testTaggedDispatch_cancelledFires() throws Exception { + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); + + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); + + map.signal(cancelKey); + + var reason = + (WakeReason) + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.CANCELLED, reason); + } + + @Test + void testTaggedDispatch_shutdownFires() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); + + map.signal(new SignalKey.Shutdown()); + + var reason = + (WakeReason) + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.SHUTDOWN, reason); + } + + @Test + void testTaggedDispatch_timeout() throws Exception { + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); + + WakeReason reason = null; + try { + reason = + (WakeReason) + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown) + .get(50, TimeUnit.MILLISECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onNotify.isDone()); + assertNull(reason); + } +} diff --git a/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java b/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java index 36f8f6f88..f5b93f309 100644 --- a/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java +++ b/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java @@ -16,12 +16,12 @@ import java.util.List; import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledForJreRange; import org.junit.jupiter.api.condition.EnabledForJreRange; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.JRE; import org.junitpioneer.jupiter.RetryingTest; import org.slf4j.Logger; @@ -69,8 +69,9 @@ public void virtualThreadPoolJava21() throws Exception { } @Test - @EnabledIfEnvironmentVariable(named = "JDKVERSION", matches = "21|25") - public void virtualThreadPoolJDK21And25() throws Exception { + public void virtualThreadPoolJDK21OrLater() throws Exception { + int jdk = Runtime.version().feature(); + Assumptions.assumeTrue(jdk >= 21, "Skipping: requires JDK 21 or later, got " + jdk); try (var dbos = new DBOS(dbosConfig)) { dbos.launch(); assertFalse(DBOSTestAccess.getDbosExecutor(dbos).usingThreadPoolExecutor()); @@ -87,8 +88,9 @@ public void threadPoolJava17() throws Exception { } @Test - @EnabledIfEnvironmentVariable(named = "JDKVERSION", matches = "17|17\\..*") - public void threadPoolJDK17() throws Exception { + public void threadPoolJDK20OrEarlier() throws Exception { + int jdk = Runtime.version().feature(); + Assumptions.assumeTrue(jdk < 21, "Skipping: requires JDK 20 or earlier, got " + jdk); try (var dbos = new DBOS(dbosConfig)) { dbos.launch(); assertTrue(DBOSTestAccess.getDbosExecutor(dbos).usingThreadPoolExecutor()); diff --git a/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java b/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java index 80f9d07a2..482bafc7b 100644 --- a/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java +++ b/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java @@ -21,7 +21,7 @@ import org.junit.jupiter.api.Test; public class CustomSchemaTest { - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); private static final String schema = "F8nny_sCHem@-n@m3"; @AutoClose DBOS dbos; private HawkService proxy; diff --git a/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java new file mode 100644 index 000000000..9a56e04cd --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java @@ -0,0 +1,115 @@ +package dev.dbos.transact.migrations; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.utils.PgContainer; + +import java.sql.Connection; + +import org.junit.jupiter.api.Test; + +/** + * Tests that LISTEN/NOTIFY functions and triggers are created on PG and omitted on CRDB. Each test + * spins up its own container so these tests are always tied to the specified DB type regardless of + * how PgContainer is configured for the rest of the test suite. + */ +class CockroachMigrationTest { + + @Test + void testPg_isCockroachReturnsFalse() throws Exception { + try (var pg = PgContainer.getPG()) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + assertFalse(SystemDatabase.isCockroach(ds)); + try (Connection conn = ds.getConnection()) { + assertFalse(SystemDatabase.isCockroach(conn)); + } + } + } + } + + @Test + void testCrdb_isCockroachReturnsTrue() throws Exception { + try (var crdb = PgContainer.getCRDB()) { + crdb.start(); + try (var ds = + SystemDatabase.createDataSource( + crdb.getJdbcUrl(), crdb.getUsername(), crdb.getPassword())) { + assertTrue(SystemDatabase.isCockroach(ds)); + try (Connection conn = ds.getConnection()) { + assertTrue(SystemDatabase.isCockroach(conn)); + } + } + } + } + + @Test + void testPg_notifyFunctionsAndTriggersPresent() throws Exception { + try (var pg = PgContainer.getPG()) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + var config = DBOSConfig.defaults("migration-notify-test").withDataSource(ds); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + var meta = conn.getMetaData(); + MigrationManagerTest.assertFunctionExists(meta, "notifications_function"); + MigrationManagerTest.assertFunctionExists(meta, "workflow_events_function"); + MigrationManagerTest.assertTriggerExists(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerExists(conn, "dbos_workflow_events_trigger"); + } + } + } + } + + @Test + void testPg_noListenNotify_notifyFunctionsAndTriggersAbsent() throws Exception { + try (var pg = PgContainer.getPG()) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + var config = + DBOSConfig.defaults("migration-notify-test") + .withDataSource(ds) + .withUseListenNotify(false); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + var meta = conn.getMetaData(); + MigrationManagerTest.assertFunctionAbsent(conn, "notifications_function"); + MigrationManagerTest.assertFunctionAbsent(conn, "workflow_events_function"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + MigrationManagerTest.assertFunctionExists(meta, "enqueue_workflow"); + MigrationManagerTest.assertFunctionExists(meta, "send_message"); + } + } + } + } + + @Test + void testCrdb_notifyFunctionsAndTriggersAbsent() throws Exception { + try (var crdb = PgContainer.getCRDB()) { + crdb.start(); + try (var ds = + SystemDatabase.createDataSource( + crdb.getJdbcUrl(), crdb.getUsername(), crdb.getPassword())) { + // useListenNotify=true intentionally — MigrationManager must override it to false for CRDB + var config = DBOSConfig.defaults("migration-notify-test").withDataSource(ds); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + MigrationManagerTest.assertFunctionAbsent(conn, "notifications_function"); + MigrationManagerTest.assertFunctionAbsent(conn, "workflow_events_function"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + } + } + } + } +} diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 1520ca39d..7f9bd9188 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -12,11 +12,13 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; +import java.sql.DriverManager; import java.sql.ResultSet; import java.util.ArrayList; import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,12 +40,20 @@ class MigrationManagerTest { "workflow_status" }; - // Expected functions after migrations - static final String[] EXPECTED_FUNCTIONS = { - "notifications_function", "workflow_events_function", "enqueue_workflow", "send_message" + // Expected functions after migrations (always present) + static final String[] EXPECTED_FUNCTIONS = {"enqueue_workflow", "send_message"}; + + // Expected LISTEN/NOTIFY functions after migrations (PG only, absent on CRDB) + static final String[] EXPECTED_NOTIFY_FUNCTIONS = { + "notifications_function", "workflow_events_function" + }; + + // Expected LISTEN/NOTIFY triggers after migrations (PG only, absent on CRDB) + static final String[] EXPECTED_NOTIFY_TRIGGERS = { + "dbos_notifications_trigger", "dbos_workflow_events_trigger" }; - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); @AutoClose HikariDataSource dataSource; @BeforeEach @@ -69,13 +79,46 @@ void testRunMigrations_CreatesTables() throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function); } + if (!PgContainer.USE_COCKROACH_DB) { + for (String function : EXPECTED_NOTIFY_FUNCTIONS) { + assertFunctionExists(metaData, function); + } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger); + } + } - var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA, true)); var version = getVersion(conn); assertEquals(migrations.size(), version); } } + @Test + void testRunMigrations_NoNotify_OmitsTriggersAndFunctions() throws Exception { + var dbosConfig = pgContainer.dbosConfig().withUseListenNotify(false); + MigrationManager.runMigrations(dbosConfig); + + try (Connection conn = dataSource.getConnection()) { + DatabaseMetaData metaData = conn.getMetaData(); + + // All tables should still exist + for (String table : EXPECTED_TABLES) { + assertTableExists(metaData, table); + } + + // enqueue_workflow and send_message are unaffected + assertFunctionExists(metaData, "enqueue_workflow"); + assertFunctionExists(metaData, "send_message"); + + // LISTEN/NOTIFY functions and triggers must be absent + assertFunctionAbsent(conn, "notifications_function"); + assertFunctionAbsent(conn, "workflow_events_function"); + assertTriggerAbsent(conn, "dbos_notifications_trigger"); + assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + } + } + @ParameterizedTest @ValueSource(strings = {"invalid\"schema", "invalid'schema"}) void testRunMigrations_fails_invalid_schema(String invalidSchema) throws Exception { @@ -101,13 +144,45 @@ void testRunMigrations_customSchema(String schema) throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function, schema); } + if (!PgContainer.USE_COCKROACH_DB) { + for (String function : EXPECTED_NOTIFY_FUNCTIONS) { + assertFunctionExists(metaData, function, schema); + } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger, schema); + } + } - var migrations = new ArrayList<>(MigrationManager.getMigrations(schema)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(schema, true)); var version = getVersion(conn, schema); assertEquals(migrations.size(), version); } } + @Test + void testRunMigrations_CreatesDatabaseIfNotExists() throws Exception { + var dbosConfig = pgContainer.dbosConfig(); + var pair = MigrationManager.extractDbAndPostgresUrl(pgContainer.jdbcUrl()); + + // Verify the database does not exist before running migrations + try (var conn = + DriverManager.getConnection(pair.url(), pgContainer.username(), pgContainer.password())) { + assertFalse( + databaseExists(conn, pair.database()), + "Database '%s' should not exist before runMigrations".formatted(pair.database())); + } + + MigrationManager.runMigrations(dbosConfig); + + // Verify the database now exists after running migrations + try (var conn = + DriverManager.getConnection(pair.url(), pgContainer.username(), pgContainer.password())) { + assertTrue( + databaseExists(conn, pair.database()), + "Database '%s' should exist after runMigrations".formatted(pair.database())); + } + } + @Test void testRunMigrations_IsIdempotent() throws Exception { @@ -126,7 +201,7 @@ void testRunMigrations_IsIdempotent() throws Exception { void testAddingNewMigration() throws Exception { testRunMigrations_CreatesTables(); - var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA, true)); migrations.add("CREATE TABLE dummy_table(id SERIAL PRIMARY KEY);"); try (var conn = dataSource.getConnection()) { @@ -168,6 +243,12 @@ public void extractDbAndPostgresUrl() { @Test void testOriginalMigration1ThenAllMigrations_NotificationsPrimaryKey() throws Exception { + Assumptions.assumeFalse(PgContainer.USE_COCKROACH_DB, "PG-only migration history test"); + + // need to create database since we are connecting + // the data source prior to dbos launch + pgContainer.createDatabase(); + try (Connection conn = dataSource.getConnection()) { // Ensure schema and migration table exist MigrationManager.ensureDbosSchema(conn, Constants.DB_SCHEMA); @@ -192,7 +273,7 @@ void testOriginalMigration1ThenAllMigrations_NotificationsPrimaryKey() throws Ex assertTableExists(metaData, "notifications"); // Now run all current migrations (including migration10 which ensures primary key) - var allMigrations = MigrationManager.getMigrations(Constants.DB_SCHEMA); + var allMigrations = MigrationManager.getMigrations(Constants.DB_SCHEMA, true); MigrationManager.runDbosMigrations(conn, Constants.DB_SCHEMA, allMigrations); // Verify that the notifications table has a primary key @@ -246,6 +327,63 @@ static int getVersion(Connection conn, String schema) throws Exception { } } + static void assertFunctionAbsent(Connection conn, String functionName) throws Exception { + String sql = + "SELECT 1 FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid" + + " WHERE n.nspname = ? AND p.proname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, Constants.DB_SCHEMA); + ps.setString(2, functionName); + try (var rs = ps.executeQuery()) { + assertFalse(rs.next(), "Function %s should not exist".formatted(functionName)); + } + } + } + + static void assertTriggerExists(Connection conn, String triggerName) throws Exception { + assertTriggerExists(conn, triggerName, Constants.DB_SCHEMA); + } + + static void assertTriggerExists(Connection conn, String triggerName, String schema) + throws Exception { + schema = SystemDatabase.sanitizeSchema(schema); + String sql = + "SELECT 1 FROM pg_trigger t JOIN pg_class c ON t.tgrelid = c.oid" + + " JOIN pg_namespace n ON c.relnamespace = n.oid" + + " WHERE n.nspname = ? AND t.tgname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, schema); + ps.setString(2, triggerName); + try (var rs = ps.executeQuery()) { + assertTrue( + rs.next(), "Trigger %s should exist in schema %s".formatted(triggerName, schema)); + } + } + } + + static boolean databaseExists(Connection conn, String dbName) throws Exception { + try (ResultSet rs = conn.getMetaData().getCatalogs()) { + while (rs.next()) { + if (dbName.equals(rs.getString("TABLE_CAT"))) return true; + } + return false; + } + } + + static void assertTriggerAbsent(Connection conn, String triggerName) throws Exception { + String sql = + "SELECT 1 FROM pg_trigger t JOIN pg_class c ON t.tgrelid = c.oid" + + " JOIN pg_namespace n ON c.relnamespace = n.oid" + + " WHERE n.nspname = ? AND t.tgname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, Constants.DB_SCHEMA); + ps.setString(2, triggerName); + try (var rs = ps.executeQuery()) { + assertFalse(rs.next(), "Trigger %s should not exist".formatted(triggerName)); + } + } + } + static void assertNotificationTableHasPrimaryKey( DatabaseMetaData metaData, String tableName, String schemaName) throws Exception { try (ResultSet rs = metaData.getPrimaryKeys(null, schemaName, tableName)) { diff --git a/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java b/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java index b03ec1855..b3d25d5d0 100644 --- a/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java +++ b/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java @@ -51,6 +51,8 @@ interface EventsService { void setMultipleEventsWorkflow(); Map getAllEventsWorkflow(String workflowId); + + Object recvWorkflow(String topic, Duration timeout); } class EventsServiceImpl implements EventsService { @@ -165,6 +167,12 @@ public Map getAllEventsWorkflow(String workflowId) { return dbos.getAllEvents(workflowId); } + @Workflow + @Override + public Object recvWorkflow(String topic, Duration timeout) { + return dbos.recv(topic, timeout).orElse(null); + } + public void resetLatches() { advanceGetLatch1 = new CountDownLatch(1); advanceGetLatch2 = new CountDownLatch(1); @@ -426,6 +434,40 @@ public void getAllEventsAppearsInSteps() throws Exception { } } + @Test + public void testGetEventTimeoutReplayable() throws Exception { + var wfid = UUID.randomUUID().toString(); + + // Run workflow — getEvent times out and records a step result with null output + try (var ctx = new WorkflowOptions(wfid).setContext()) { + assertNull(proxy.getEventWorkflow("nonexistent-wfid", "somekey", Duration.ofMillis(50))); + } + + // Simulate crash: keep step result, reset workflow to PENDING + DBUtils.setWorkflowState(dataSource, wfid, WorkflowState.PENDING.name()); + + // Recover — workflow body re-executes, getEvent step replays via toResult() + var handle = DBOSTestAccess.getDbosExecutor(dbos).executeWorkflowById(wfid, true, false); + assertNull(handle.getResult()); + } + + @Test + public void testRecvTimeoutReplayable() throws Exception { + var wfid = UUID.randomUUID().toString(); + + // Run workflow — recv times out and records a step result with null output + try (var ctx = new WorkflowOptions(wfid).setContext()) { + assertNull(proxy.recvWorkflow("some-topic", Duration.ofMillis(50))); + } + + // Simulate crash: keep step result, reset workflow to PENDING + DBUtils.setWorkflowState(dataSource, wfid, WorkflowState.PENDING.name()); + + // Recover — workflow body re-executes, recv step replays via toResult() + var handle = DBOSTestAccess.getDbosExecutor(dbos).executeWorkflowById(wfid, true, false); + assertNull(handle.getResult()); + } + @Test public void concurrency() throws Exception { ExecutorService executor = Executors.newFixedThreadPool(2); diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java index 7458e8df2..3a9dafbdb 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java @@ -17,10 +17,16 @@ import java.util.Objects; import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class JdbcStepFactoryInitTest { - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); + + @BeforeEach + void beforeEach() throws SQLException { + pgContainer.createDatabase(); + } static boolean validateSchema(Connection conn, String schema) throws SQLException { Objects.requireNonNull(schema); diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java index 61db3688e..e2cba8168 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java @@ -123,11 +123,15 @@ public class JdbcStepFactoryTest { @BeforeEach void beforeEach() throws SQLException { + pgContainer.createDatabase(); + dbosConfig = pgContainer.dbosConfig(); dataSource = pgContainer.dataSource(); try (var conn = dataSource.getConnection(); var stmt = conn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS greetings"); + stmt.execute("DROP TABLE IF EXISTS dbos.tx_step_outputs"); stmt.execute( "CREATE TABLE greetings(name text NOT NULL, greet_count integer DEFAULT 0, PRIMARY KEY(name))"); } diff --git a/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java b/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java new file mode 100644 index 000000000..8c45d98d1 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java @@ -0,0 +1,50 @@ +package dev.dbos.transact.utils; + +import org.junit.platform.engine.ConfigurationParameters; +import org.junit.platform.engine.support.hierarchical.DefaultParallelExecutionConfigurationStrategy; +import org.junit.platform.engine.support.hierarchical.ParallelExecutionConfiguration; +import org.junit.platform.engine.support.hierarchical.ParallelExecutionConfigurationStrategy; + +public class CrdbParallelExecutionConfigurationStrategy + implements ParallelExecutionConfigurationStrategy { + + @Override + public ParallelExecutionConfiguration createConfiguration( + ConfigurationParameters configurationParameters) { + if (PgContainer.USE_COCKROACH_DB) { + int parallelism = Runtime.getRuntime().availableProcessors() >= 8 ? 2 : 1; + return fixedConfig(parallelism); + } + return DefaultParallelExecutionConfigurationStrategy.DYNAMIC.createConfiguration( + configurationParameters); + } + + private static ParallelExecutionConfiguration fixedConfig(int parallelism) { + return new ParallelExecutionConfiguration() { + @Override + public int getParallelism() { + return parallelism; + } + + @Override + public int getMinimumRunnable() { + return parallelism; + } + + @Override + public int getMaxPoolSize() { + return parallelism + 256; + } + + @Override + public int getCorePoolSize() { + return parallelism; + } + + @Override + public int getKeepAliveSeconds() { + return 30; + } + }; + } +} diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 2c1431dd5..6c0297919 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -3,85 +3,112 @@ import dev.dbos.transact.DBOSClient; import dev.dbos.transact.config.DBOSConfig; import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.migrations.MigrationManager; +import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; -import java.util.ArrayList; import java.util.Objects; -import java.util.UUID; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Semaphore; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import com.zaxxer.hikari.HikariDataSource; +import org.testcontainers.cockroachdb.CockroachContainer; +import org.testcontainers.containers.JdbcDatabaseContainer; import org.testcontainers.postgresql.PostgreSQLContainer; public class PgContainer implements AutoCloseable { - private static final int SIZE = Runtime.getRuntime().availableProcessors(); - private static final BlockingQueue POOL = new ArrayBlockingQueue<>(SIZE); - private static final Semaphore PERMITS = new Semaphore(SIZE); - - static { - Runtime.getRuntime() - .addShutdownHook( - new Thread( - () -> { - var containers = new ArrayList(); - POOL.drainTo(containers); - containers.forEach(PostgreSQLContainer::stop); - })); - } - - static PostgreSQLContainer acquire() { - try { - PERMITS.acquire(); - var container = POOL.poll(); - if (container == null) { - container = new PostgreSQLContainer("postgres:18"); - container.start(); + public static final boolean USE_COCKROACH_DB = + Boolean.parseBoolean(System.getenv("DBOS_TEST_USE_COCKROACH_DB")); + private static final String DB_NAME = "dbos_test_db"; + + private static final Queue> POOL = new ConcurrentLinkedQueue<>(); + + public static PostgreSQLContainer getPG() { + return new PostgreSQLContainer("postgres:18"); + } + + public static CockroachContainer getCRDB() { + return new CockroachContainer("cockroachdb/cockroach:latest-v26.2"); + } + + private static JdbcDatabaseContainer containerSupplier() { + var container = USE_COCKROACH_DB ? getCRDB() : getPG(); + container.start(); + return container; + } + + static JdbcDatabaseContainer acquire() { + var container = POOL.poll(); + if (container != null) { + var jdbcUrl = container.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + try (var conn = + DriverManager.getConnection(jdbcUrl, container.getUsername(), container.getPassword())) { + truncateDbosTables(conn); + } catch (SQLException e) { + throw new RuntimeException(e); } return container; - } catch (InterruptedException e) { - throw new RuntimeException(e); } + container = containerSupplier(); + var jdbcUrl = container.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + + MigrationManager.runMigrations( + jdbcUrl, container.getUsername(), container.getPassword(), "dbos", true); + return container; } - static void release(PostgreSQLContainer c) { + static void release(JdbcDatabaseContainer c) { POOL.offer(c); - PERMITS.release(); } - private final PostgreSQLContainer pgContainer; + public static void truncateDbosTables(Connection conn) throws SQLException { + // truncate the DBOS tables from the test DB before returning to the pool + var truncate = + """ + TRUNCATE TABLE + "dbos".workflow_status, + "dbos".operation_outputs, + "dbos".workflow_events, + "dbos".workflow_events_history, + "dbos".notifications, + "dbos".event_dispatch_kv, + "dbos".streams, + "dbos".application_versions, + "dbos".workflow_schedules + CASCADE + """; + try (var stmt = conn.createStatement()) { + stmt.execute(truncate); + } + } + + private final JdbcDatabaseContainer pgContainer; private final String jdbcUrl; - private final String dbName; + private final boolean pooled; public PgContainer() { - // take a container from the pool and create a new database for it - pgContainer = acquire(); - dbName = "test_" + UUID.randomUUID().toString().replace("-", ""); - jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + dbName); - - try (var conn = - DriverManager.getConnection( - pgContainer.getJdbcUrl(), pgContainer.getUsername(), pgContainer.getPassword()); - var stmt = conn.createStatement()) { - stmt.execute("CREATE DATABASE " + dbName); - } catch (SQLException e) { - throw new RuntimeException(e); - } + this(false); + } + + private PgContainer(boolean requireFresh) { + pooled = !requireFresh; + pgContainer = pooled ? acquire() : containerSupplier(); + jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + } + + public static PgContainer createFresh() { + return new PgContainer(true); } @Override public void close() throws Exception { - // drop the database we created and return the container too the pool - var _jdbcUrl = pgContainer.getJdbcUrl(); - try (var conn = DriverManager.getConnection(_jdbcUrl, username(), password()); - var stmt = conn.createStatement()) { - var sql = "DROP DATABASE IF EXISTS %s WITH (FORCE)".formatted(dbName); - stmt.execute(sql); + if (pooled) { + release(pgContainer); + } else { + pgContainer.close(); } - release(pgContainer); } public String jdbcUrl() { @@ -114,4 +141,8 @@ public HikariDataSource dataSource() { public DBOSClient dbosClient() { return new DBOSClient(jdbcUrl(), username(), password()); } + + public void createDatabase() { + MigrationManager.createDatabaseIfNotExists(jdbcUrl(), username(), password()); + } } diff --git a/transact/src/test/resources/junit-platform.properties b/transact/src/test/resources/junit-platform.properties index f67fe92f1..bbef1c697 100644 --- a/transact/src/test/resources/junit-platform.properties +++ b/transact/src/test/resources/junit-platform.properties @@ -1,6 +1,7 @@ junit.jupiter.execution.parallel.enabled = true junit.jupiter.execution.parallel.mode.default = concurrent junit.jupiter.execution.parallel.mode.classes.default = concurrent -junit.jupiter.execution.parallel.config.strategy = dynamic +junit.jupiter.execution.parallel.config.strategy = custom +junit.jupiter.execution.parallel.config.custom.class = dev.dbos.transact.utils.CrdbParallelExecutionConfigurationStrategy junit.jupiter.execution.parallel.config.dynamic.factor = 1.0 junit.jupiter.execution.timeout.default = 2 m